1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"context"
10	"crypto/ecdsa"
11	"crypto/elliptic"
12	"crypto/rand"
13	"crypto/rsa"
14	"crypto/x509"
15	"crypto/x509/pkix"
16	"encoding/base64"
17	"encoding/hex"
18	"encoding/pem"
19	"errors"
20	"fmt"
21	"internal/byteorder"
22	"io"
23	"math/big"
24	"net"
25	"os"
26	"os/exec"
27	"path/filepath"
28	"reflect"
29	"runtime"
30	"strconv"
31	"strings"
32	"testing"
33	"time"
34)
35
36// Note: see comment in handshake_test.go for details of how the reference
37// tests work.
38
// opensslInputEvent enumerates possible inputs that can be sent to the
// `openssl s_server` reference process through its standard input.
type opensslInputEvent int

const (
	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
	// connection.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel causes OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel

	// opensslKeyUpdate causes OpenSSL to send a key update message to the
	// client and request one back.
	opensslKeyUpdate
)

// opensslSentinel is the text OpenSSL is asked to send on the connection in
// response to opensslSendSentinel.
const opensslSentinel = "SENTINEL\n"

// opensslInput is used as the standard input of the OpenSSL child process.
// Each event sent on the channel is translated into a short text command that
// Read hands to the child; closing the channel ends the input stream.
type opensslInput chan opensslInputEvent

// Read blocks until the next event arrives and copies the matching command
// into buf. It returns io.EOF once the channel has been closed and panics on an
// unknown event. Each command is delivered by a single Read, so buf must be at
// least len(opensslSentinel) bytes long; a shorter buf truncates the command.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	for event := range i {
		switch event {
		case opensslRenegotiate:
			return copy(buf, "R\n"), nil
		case opensslKeyUpdate:
			return copy(buf, "K\n"), nil
		case opensslSendSentinel:
			return copy(buf, opensslSentinel), nil
		default:
			panic("unknown event")
		}
	}

	return 0, io.EOF
}
77
// opensslOutputSink is an io.Writer that collects everything an `openssl`
// child process writes to stdout and stderr. Output is split into lines, and
// whenever a complete line equals opensslEndOfHandshake or
// opensslReadKeyUpdate a value is sent on handshakeComplete or readKeyUpdate
// respectively. Both channels are unbuffered, so Write blocks until the signal
// is received.
type opensslOutputSink struct {
	handshakeComplete chan struct{}
	readKeyUpdate     chan struct{}
	all               []byte
	line              []byte
}

// newOpensslOutputSink returns a sink with fresh unbuffered signal channels and
// no accumulated output.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{
		handshakeComplete: make(chan struct{}),
		readKeyUpdate:     make(chan struct{}),
	}
}

// opensslEndOfHandshake is a message that the “openssl s_server” tool will
// print when a handshake completes if run with “-state”.
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// opensslReadKeyUpdate is a message that the “openssl s_server” tool will
// print when a KeyUpdate message is received if run with “-state”.
const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"

// Write records data and emits a signal for every complete line that matches
// one of the watched messages. Any trailing partial line is kept until a later
// Write completes it. It always reports the whole of data as written.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.all = append(o.all, data...)
	o.line = append(o.line, data...)

	for {
		end := bytes.IndexByte(o.line, '\n')
		if end < 0 {
			return len(data), nil
		}

		switch string(o.line[:end]) {
		case opensslEndOfHandshake:
			o.handshakeComplete <- struct{}{}
		case opensslReadKeyUpdate:
			o.readKeyUpdate <- struct{}{}
		}
		o.line = o.line[end+1:]
	}
}

// String returns all output received so far.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}
125
126// clientTest represents a test of the TLS client handshake against a reference
127// implementation.
128type clientTest struct {
129	// name is a freeform string identifying the test and the file in which
130	// the expected results will be stored.
131	name string
132	// args, if not empty, contains a series of arguments for the
133	// command to run for the reference server.
134	args []string
135	// config, if not nil, contains a custom Config to use for this test.
136	config *Config
137	// cert, if not empty, contains a DER-encoded certificate for the
138	// reference server.
139	cert []byte
140	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
141	// *ecdsa.PrivateKey which is the private key for the reference server.
142	key any
143	// extensions, if not nil, contains a list of extension data to be returned
144	// from the ServerHello. The data should be in standard TLS format with
145	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
146	extensions [][]byte
147	// validate, if not nil, is a function that will be called with the
148	// ConnectionState of the resulting connection. It returns a non-nil
149	// error if the ConnectionState is unacceptable.
150	validate func(ConnectionState) error
151	// numRenegotiations is the number of times that the connection will be
152	// renegotiated.
153	numRenegotiations int
154	// renegotiationExpectedToFail, if not zero, is the number of the
155	// renegotiation attempt that is expected to fail.
156	renegotiationExpectedToFail int
157	// checkRenegotiationError, if not nil, is called with any error
158	// arising from renegotiation. It can map expected errors to nil to
159	// ignore them.
160	checkRenegotiationError func(renegotiationNum int, err error) error
161	// sendKeyUpdate will cause the server to send a KeyUpdate message.
162	sendKeyUpdate bool
163}
164
165var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
166
167// connFromCommand starts the reference server process, connects to it and
168// returns a recordingConn for the connection. The stdin return value is an
169// opensslInput for the stdin of the child process. It must be closed before
170// Waiting for child.
171func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
172	cert := testRSACertificate
173	if len(test.cert) > 0 {
174		cert = test.cert
175	}
176	certPath := tempFile(string(cert))
177	defer os.Remove(certPath)
178
179	var key any = testRSAPrivateKey
180	if test.key != nil {
181		key = test.key
182	}
183	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
184	if err != nil {
185		panic(err)
186	}
187
188	var pemOut bytes.Buffer
189	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})
190
191	keyPath := tempFile(pemOut.String())
192	defer os.Remove(keyPath)
193
194	var command []string
195	command = append(command, serverCommand...)
196	command = append(command, test.args...)
197	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
198	// serverPort contains the port that OpenSSL will listen on. OpenSSL
199	// can't take "0" as an argument here so we have to pick a number and
200	// hope that it's not in use on the machine. Since this only occurs
201	// when -update is given and thus when there's a human watching the
202	// test, this isn't too bad.
203	const serverPort = 24323
204	command = append(command, "-accept", strconv.Itoa(serverPort))
205
206	if len(test.extensions) > 0 {
207		var serverInfo bytes.Buffer
208		for _, ext := range test.extensions {
209			pem.Encode(&serverInfo, &pem.Block{
210				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", byteorder.BeUint16(ext)),
211				Bytes: ext,
212			})
213		}
214		serverInfoPath := tempFile(serverInfo.String())
215		defer os.Remove(serverInfoPath)
216		command = append(command, "-serverinfo", serverInfoPath)
217	}
218
219	if test.numRenegotiations > 0 || test.sendKeyUpdate {
220		found := false
221		for _, flag := range command[1:] {
222			if flag == "-state" {
223				found = true
224				break
225			}
226		}
227
228		if !found {
229			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
230		}
231	}
232
233	cmd := exec.Command(command[0], command[1:]...)
234	stdin = opensslInput(make(chan opensslInputEvent))
235	cmd.Stdin = stdin
236	out := newOpensslOutputSink()
237	cmd.Stdout = out
238	cmd.Stderr = out
239	if err := cmd.Start(); err != nil {
240		return nil, nil, nil, nil, err
241	}
242
243	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
244	// opening the listening socket, so we can't use that to wait until it
245	// has started listening. Thus we are forced to poll until we get a
246	// connection.
247	var tcpConn net.Conn
248	for i := uint(0); i < 5; i++ {
249		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
250			IP:   net.IPv4(127, 0, 0, 1),
251			Port: serverPort,
252		})
253		if err == nil {
254			break
255		}
256		time.Sleep((1 << i) * 5 * time.Millisecond)
257	}
258	if err != nil {
259		close(stdin)
260		cmd.Process.Kill()
261		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
262		return nil, nil, nil, nil, err
263	}
264
265	record := &recordingConn{
266		Conn: tcpConn,
267	}
268
269	return record, cmd, stdin, out, nil
270}
271
272func (test *clientTest) dataPath() string {
273	return filepath.Join("testdata", "Client-"+test.name)
274}
275
276func (test *clientTest) loadData() (flows [][]byte, err error) {
277	in, err := os.Open(test.dataPath())
278	if err != nil {
279		return nil, err
280	}
281	defer in.Close()
282	return parseTestData(in)
283}
284
285func (test *clientTest) run(t *testing.T, write bool) {
286	var clientConn net.Conn
287	var recordingConn *recordingConn
288	var childProcess *exec.Cmd
289	var stdin opensslInput
290	var stdout *opensslOutputSink
291
292	if write {
293		var err error
294		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
295		if err != nil {
296			t.Fatalf("Failed to start subcommand: %s", err)
297		}
298		clientConn = recordingConn
299		defer func() {
300			if t.Failed() {
301				t.Logf("OpenSSL output:\n\n%s", stdout.all)
302			}
303		}()
304	} else {
305		flows, err := test.loadData()
306		if err != nil {
307			t.Fatalf("failed to load data from %s: %v", test.dataPath(), err)
308		}
309		clientConn = &replayingConn{t: t, flows: flows, reading: false}
310	}
311
312	config := test.config
313	if config == nil {
314		config = testConfig
315	}
316	client := Client(clientConn, config)
317	defer client.Close()
318
319	if _, err := client.Write([]byte("hello\n")); err != nil {
320		t.Errorf("Client.Write failed: %s", err)
321		return
322	}
323
324	for i := 1; i <= test.numRenegotiations; i++ {
325		// The initial handshake will generate a
326		// handshakeComplete signal which needs to be quashed.
327		if i == 1 && write {
328			<-stdout.handshakeComplete
329		}
330
331		// OpenSSL will try to interleave application data and
332		// a renegotiation if we send both concurrently.
333		// Therefore: ask OpensSSL to start a renegotiation, run
334		// a goroutine to call client.Read and thus process the
335		// renegotiation request, watch for OpenSSL's stdout to
336		// indicate that the handshake is complete and,
337		// finally, have OpenSSL write something to cause
338		// client.Read to complete.
339		if write {
340			stdin <- opensslRenegotiate
341		}
342
343		signalChan := make(chan struct{})
344
345		go func() {
346			defer close(signalChan)
347
348			buf := make([]byte, 256)
349			n, err := client.Read(buf)
350
351			if test.checkRenegotiationError != nil {
352				newErr := test.checkRenegotiationError(i, err)
353				if err != nil && newErr == nil {
354					return
355				}
356				err = newErr
357			}
358
359			if err != nil {
360				t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
361				return
362			}
363
364			buf = buf[:n]
365			if !bytes.Equal([]byte(opensslSentinel), buf) {
366				t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
367			}
368
369			if expected := i + 1; client.handshakes != expected {
370				t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
371			}
372		}()
373
374		if write && test.renegotiationExpectedToFail != i {
375			<-stdout.handshakeComplete
376			stdin <- opensslSendSentinel
377		}
378		<-signalChan
379	}
380
381	if test.sendKeyUpdate {
382		if write {
383			<-stdout.handshakeComplete
384			stdin <- opensslKeyUpdate
385		}
386
387		doneRead := make(chan struct{})
388
389		go func() {
390			defer close(doneRead)
391
392			buf := make([]byte, 256)
393			n, err := client.Read(buf)
394
395			if err != nil {
396				t.Errorf("Client.Read failed after KeyUpdate: %s", err)
397				return
398			}
399
400			buf = buf[:n]
401			if !bytes.Equal([]byte(opensslSentinel), buf) {
402				t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
403			}
404		}()
405
406		if write {
407			// There's no real reason to wait for the client KeyUpdate to
408			// send data with the new server keys, except that s_server
409			// drops writes if they are sent at the wrong time.
410			<-stdout.readKeyUpdate
411			stdin <- opensslSendSentinel
412		}
413		<-doneRead
414
415		if _, err := client.Write([]byte("hello again\n")); err != nil {
416			t.Errorf("Client.Write failed: %s", err)
417			return
418		}
419	}
420
421	if test.validate != nil {
422		if err := test.validate(client.ConnectionState()); err != nil {
423			t.Errorf("validate callback returned error: %s", err)
424		}
425	}
426
427	// If the server sent us an alert after our last flight, give it a
428	// chance to arrive.
429	if write && test.renegotiationExpectedToFail == 0 {
430		if err := peekError(client); err != nil {
431			t.Errorf("final Read returned an error: %s", err)
432		}
433	}
434
435	if write {
436		clientConn.Close()
437		path := test.dataPath()
438		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
439		if err != nil {
440			t.Fatalf("Failed to create output file: %s", err)
441		}
442		defer out.Close()
443		recordingConn.Close()
444		close(stdin)
445		childProcess.Process.Kill()
446		childProcess.Wait()
447		if len(recordingConn.flows) < 3 {
448			t.Fatalf("Client connection didn't work")
449		}
450		recordingConn.WriteTo(out)
451		t.Logf("Wrote %s\n", path)
452	}
453}
454
// peekError attempts a one-byte read bounded by a 100ms deadline to find out
// whether the next read would fail, e.g. because an alert is waiting on the
// wire. A timeout counts as success. Receiving data is reported as an error,
// and any other read error is returned as is. The read deadline is left set.
func peekError(conn net.Conn) error {
	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))

	n, err := conn.Read(make([]byte, 1))
	switch {
	case n != 0:
		return errors.New("unexpectedly read data")
	case err == nil:
		return nil
	}

	if netErr, isNetErr := err.(net.Error); isNetErr && netErr.Timeout() {
		return nil
	}
	return err
}
468
469func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
470	// Make a deep copy of the template before going parallel.
471	test := *template
472	if template.config != nil {
473		test.config = template.config.Clone()
474	}
475	test.name = version + "-" + test.name
476	test.args = append([]string{option}, test.args...)
477
478	runTestAndUpdateIfNeeded(t, version, test.run, false)
479}
480
481func runClientTestTLS10(t *testing.T, template *clientTest) {
482	runClientTestForVersion(t, template, "TLSv10", "-tls1")
483}
484
485func runClientTestTLS11(t *testing.T, template *clientTest) {
486	runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
487}
488
489func runClientTestTLS12(t *testing.T, template *clientTest) {
490	runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
491}
492
493func runClientTestTLS13(t *testing.T, template *clientTest) {
494	runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
495}
496
497func TestHandshakeClientRSARC4(t *testing.T) {
498	test := &clientTest{
499		name: "RSA-RC4",
500		args: []string{"-cipher", "RC4-SHA"},
501	}
502	runClientTestTLS10(t, test)
503	runClientTestTLS11(t, test)
504	runClientTestTLS12(t, test)
505}
506
507func TestHandshakeClientRSAAES128GCM(t *testing.T) {
508	test := &clientTest{
509		name: "AES128-GCM-SHA256",
510		args: []string{"-cipher", "AES128-GCM-SHA256"},
511	}
512	runClientTestTLS12(t, test)
513}
514
515func TestHandshakeClientRSAAES256GCM(t *testing.T) {
516	test := &clientTest{
517		name: "AES256-GCM-SHA384",
518		args: []string{"-cipher", "AES256-GCM-SHA384"},
519	}
520	runClientTestTLS12(t, test)
521}
522
523func TestHandshakeClientECDHERSAAES(t *testing.T) {
524	test := &clientTest{
525		name: "ECDHE-RSA-AES",
526		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
527	}
528	runClientTestTLS10(t, test)
529	runClientTestTLS11(t, test)
530	runClientTestTLS12(t, test)
531}
532
533func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
534	test := &clientTest{
535		name: "ECDHE-ECDSA-AES",
536		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
537		cert: testECDSACertificate,
538		key:  testECDSAPrivateKey,
539	}
540	runClientTestTLS10(t, test)
541	runClientTestTLS11(t, test)
542	runClientTestTLS12(t, test)
543}
544
545func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
546	test := &clientTest{
547		name: "ECDHE-ECDSA-AES-GCM",
548		args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
549		cert: testECDSACertificate,
550		key:  testECDSAPrivateKey,
551	}
552	runClientTestTLS12(t, test)
553}
554
555func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
556	test := &clientTest{
557		name: "ECDHE-ECDSA-AES256-GCM-SHA384",
558		args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
559		cert: testECDSACertificate,
560		key:  testECDSAPrivateKey,
561	}
562	runClientTestTLS12(t, test)
563}
564
565func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
566	test := &clientTest{
567		name: "AES128-SHA256",
568		args: []string{"-cipher", "AES128-SHA256"},
569	}
570	runClientTestTLS12(t, test)
571}
572
573func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
574	test := &clientTest{
575		name: "ECDHE-RSA-AES128-SHA256",
576		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
577	}
578	runClientTestTLS12(t, test)
579}
580
581func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
582	test := &clientTest{
583		name: "ECDHE-ECDSA-AES128-SHA256",
584		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
585		cert: testECDSACertificate,
586		key:  testECDSAPrivateKey,
587	}
588	runClientTestTLS12(t, test)
589}
590
591func TestHandshakeClientX25519(t *testing.T) {
592	config := testConfig.Clone()
593	config.CurvePreferences = []CurveID{X25519}
594
595	test := &clientTest{
596		name:   "X25519-ECDHE",
597		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
598		config: config,
599	}
600
601	runClientTestTLS12(t, test)
602	runClientTestTLS13(t, test)
603}
604
605func TestHandshakeClientP256(t *testing.T) {
606	config := testConfig.Clone()
607	config.CurvePreferences = []CurveID{CurveP256}
608
609	test := &clientTest{
610		name:   "P256-ECDHE",
611		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
612		config: config,
613	}
614
615	runClientTestTLS12(t, test)
616	runClientTestTLS13(t, test)
617}
618
619func TestHandshakeClientHelloRetryRequest(t *testing.T) {
620	config := testConfig.Clone()
621	config.CurvePreferences = []CurveID{X25519, CurveP256}
622
623	test := &clientTest{
624		name:   "HelloRetryRequest",
625		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
626		config: config,
627		validate: func(cs ConnectionState) error {
628			if !cs.testingOnlyDidHRR {
629				return errors.New("expected HelloRetryRequest")
630			}
631			return nil
632		},
633	}
634
635	runClientTestTLS13(t, test)
636}
637
638func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
639	config := testConfig.Clone()
640	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
641
642	test := &clientTest{
643		name:   "ECDHE-RSA-CHACHA20-POLY1305",
644		args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
645		config: config,
646	}
647
648	runClientTestTLS12(t, test)
649}
650
651func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
652	config := testConfig.Clone()
653	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
654
655	test := &clientTest{
656		name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
657		args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
658		config: config,
659		cert:   testECDSACertificate,
660		key:    testECDSAPrivateKey,
661	}
662
663	runClientTestTLS12(t, test)
664}
665
666func TestHandshakeClientAES128SHA256(t *testing.T) {
667	test := &clientTest{
668		name: "AES128-SHA256",
669		args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
670	}
671	runClientTestTLS13(t, test)
672}
673func TestHandshakeClientAES256SHA384(t *testing.T) {
674	test := &clientTest{
675		name: "AES256-SHA384",
676		args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
677	}
678	runClientTestTLS13(t, test)
679}
680func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
681	test := &clientTest{
682		name: "CHACHA20-SHA256",
683		args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
684	}
685	runClientTestTLS13(t, test)
686}
687
688func TestHandshakeClientECDSATLS13(t *testing.T) {
689	test := &clientTest{
690		name: "ECDSA",
691		cert: testECDSACertificate,
692		key:  testECDSAPrivateKey,
693	}
694	runClientTestTLS13(t, test)
695}
696
697func TestHandshakeClientEd25519(t *testing.T) {
698	test := &clientTest{
699		name: "Ed25519",
700		cert: testEd25519Certificate,
701		key:  testEd25519PrivateKey,
702	}
703	runClientTestTLS12(t, test)
704	runClientTestTLS13(t, test)
705
706	config := testConfig.Clone()
707	cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
708	config.Certificates = []Certificate{cert}
709
710	test = &clientTest{
711		name:   "ClientCert-Ed25519",
712		args:   []string{"-Verify", "1"},
713		config: config,
714	}
715
716	runClientTestTLS12(t, test)
717	runClientTestTLS13(t, test)
718}
719
720func TestHandshakeClientCertRSA(t *testing.T) {
721	config := testConfig.Clone()
722	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
723	config.Certificates = []Certificate{cert}
724
725	test := &clientTest{
726		name:   "ClientCert-RSA-RSA",
727		args:   []string{"-cipher", "AES128", "-Verify", "1"},
728		config: config,
729	}
730
731	runClientTestTLS10(t, test)
732	runClientTestTLS12(t, test)
733
734	test = &clientTest{
735		name:   "ClientCert-RSA-ECDSA",
736		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
737		config: config,
738		cert:   testECDSACertificate,
739		key:    testECDSAPrivateKey,
740	}
741
742	runClientTestTLS10(t, test)
743	runClientTestTLS12(t, test)
744	runClientTestTLS13(t, test)
745
746	test = &clientTest{
747		name:   "ClientCert-RSA-AES256-GCM-SHA384",
748		args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
749		config: config,
750		cert:   testRSACertificate,
751		key:    testRSAPrivateKey,
752	}
753
754	runClientTestTLS12(t, test)
755}
756
757func TestHandshakeClientCertECDSA(t *testing.T) {
758	config := testConfig.Clone()
759	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
760	config.Certificates = []Certificate{cert}
761
762	test := &clientTest{
763		name:   "ClientCert-ECDSA-RSA",
764		args:   []string{"-cipher", "AES128", "-Verify", "1"},
765		config: config,
766	}
767
768	runClientTestTLS10(t, test)
769	runClientTestTLS12(t, test)
770	runClientTestTLS13(t, test)
771
772	test = &clientTest{
773		name:   "ClientCert-ECDSA-ECDSA",
774		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
775		config: config,
776		cert:   testECDSACertificate,
777		key:    testECDSAPrivateKey,
778	}
779
780	runClientTestTLS10(t, test)
781	runClientTestTLS12(t, test)
782}
783
784// TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
785// client and server certificates. It also serves from both sides a certificate
786// signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
787// works.
788func TestHandshakeClientCertRSAPSS(t *testing.T) {
789	cert, err := x509.ParseCertificate(testRSAPSSCertificate)
790	if err != nil {
791		panic(err)
792	}
793	rootCAs := x509.NewCertPool()
794	rootCAs.AddCert(cert)
795
796	config := testConfig.Clone()
797	// Use GetClientCertificate to bypass the client certificate selection logic.
798	config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
799		return &Certificate{
800			Certificate: [][]byte{testRSAPSSCertificate},
801			PrivateKey:  testRSAPrivateKey,
802		}, nil
803	}
804	config.RootCAs = rootCAs
805
806	test := &clientTest{
807		name: "ClientCert-RSA-RSAPSS",
808		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
809			"rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
810		config: config,
811		cert:   testRSAPSSCertificate,
812		key:    testRSAPrivateKey,
813	}
814	runClientTestTLS12(t, test)
815	runClientTestTLS13(t, test)
816}
817
818func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
819	config := testConfig.Clone()
820	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
821	config.Certificates = []Certificate{cert}
822
823	test := &clientTest{
824		name: "ClientCert-RSA-RSAPKCS1v15",
825		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
826			"rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
827		config: config,
828	}
829
830	runClientTestTLS12(t, test)
831}
832
833func TestClientKeyUpdate(t *testing.T) {
834	test := &clientTest{
835		name:          "KeyUpdate",
836		args:          []string{"-state"},
837		sendKeyUpdate: true,
838	}
839	runClientTestTLS13(t, test)
840}
841
842func TestResumption(t *testing.T) {
843	t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
844	t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
845}
846
847func testResumption(t *testing.T, version uint16) {
848	if testing.Short() {
849		t.Skip("skipping in -short mode")
850	}
851	serverConfig := &Config{
852		MaxVersion:   version,
853		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
854		Certificates: testConfig.Certificates,
855	}
856
857	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
858	if err != nil {
859		panic(err)
860	}
861
862	rootCAs := x509.NewCertPool()
863	rootCAs.AddCert(issuer)
864
865	clientConfig := &Config{
866		MaxVersion:         version,
867		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
868		ClientSessionCache: NewLRUClientSessionCache(32),
869		RootCAs:            rootCAs,
870		ServerName:         "example.golang",
871	}
872
873	testResumeState := func(test string, didResume bool) {
874		t.Helper()
875		_, hs, err := testHandshake(t, clientConfig, serverConfig)
876		if err != nil {
877			t.Fatalf("%s: handshake failed: %s", test, err)
878		}
879		if hs.DidResume != didResume {
880			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
881		}
882		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
883			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
884		}
885		if got, want := hs.ServerName, clientConfig.ServerName; got != want {
886			t.Errorf("%s: server name %s, want %s", test, got, want)
887		}
888	}
889
890	getTicket := func() []byte {
891		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.ticket
892	}
893	deleteTicket := func() {
894		ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
895		clientConfig.ClientSessionCache.Put(ticketKey, nil)
896	}
897	corruptTicket := func() {
898		clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.session.secret[0] ^= 0xff
899	}
900	randomKey := func() [32]byte {
901		var k [32]byte
902		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
903			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
904		}
905		return k
906	}
907
908	testResumeState("Handshake", false)
909	ticket := getTicket()
910	testResumeState("Resume", true)
911	if bytes.Equal(ticket, getTicket()) {
912		t.Fatal("ticket didn't change after resumption")
913	}
914
915	// An old session ticket is replaced with a ticket encrypted with a fresh key.
916	ticket = getTicket()
917	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
918	testResumeState("ResumeWithOldTicket", true)
919	if bytes.Equal(ticket, getTicket()) {
920		t.Fatal("old first ticket matches the fresh one")
921	}
922
923	// Once the session master secret is expired, a full handshake should occur.
924	ticket = getTicket()
925	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
926	testResumeState("ResumeWithExpiredTicket", false)
927	if bytes.Equal(ticket, getTicket()) {
928		t.Fatal("expired first ticket matches the fresh one")
929	}
930
931	serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
932	key1 := randomKey()
933	serverConfig.SetSessionTicketKeys([][32]byte{key1})
934
935	testResumeState("InvalidSessionTicketKey", false)
936	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
937
938	key2 := randomKey()
939	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
940	ticket = getTicket()
941	testResumeState("KeyChange", true)
942	if bytes.Equal(ticket, getTicket()) {
943		t.Fatal("new ticket wasn't included while resuming")
944	}
945	testResumeState("KeyChangeFinish", true)
946
947	// Age the session ticket a bit, but not yet expired.
948	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
949	testResumeState("OldSessionTicket", true)
950	ticket = getTicket()
951	// Expire the session ticket, which would force a full handshake.
952	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
953	testResumeState("ExpiredSessionTicket", false)
954	if bytes.Equal(ticket, getTicket()) {
955		t.Fatal("new ticket wasn't provided after old ticket expired")
956	}
957
958	// Age the session ticket a bit at a time, but don't expire it.
959	d := 0 * time.Hour
960	serverConfig.Time = func() time.Time { return time.Now().Add(d) }
961	deleteTicket()
962	testResumeState("GetFreshSessionTicket", false)
963	for i := 0; i < 13; i++ {
964		d += 12 * time.Hour
965		testResumeState("OldSessionTicket", true)
966	}
967	// Expire it (now a little more than 7 days) and make sure a full
968	// handshake occurs for TLS 1.2. Resumption should still occur for
969	// TLS 1.3 since the client should be using a fresh ticket sent over
970	// by the server.
971	d += 12 * time.Hour
972	if version == VersionTLS13 {
973		testResumeState("ExpiredSessionTicket", true)
974	} else {
975		testResumeState("ExpiredSessionTicket", false)
976	}
977	if bytes.Equal(ticket, getTicket()) {
978		t.Fatal("new ticket wasn't provided after old ticket expired")
979	}
980
981	// Reset serverConfig to ensure that calling SetSessionTicketKeys
982	// before the serverConfig is used works.
983	serverConfig = &Config{
984		MaxVersion:   version,
985		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
986		Certificates: testConfig.Certificates,
987	}
988	serverConfig.SetSessionTicketKeys([][32]byte{key2})
989
990	testResumeState("FreshConfig", true)
991
992	// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
993	// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
994	if version != VersionTLS13 {
995		clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
996		testResumeState("DifferentCipherSuite", false)
997		testResumeState("DifferentCipherSuiteRecovers", true)
998	}
999
1000	deleteTicket()
1001	testResumeState("WithoutSessionTicket", false)
1002
1003	// In TLS 1.3, HelloRetryRequest is sent after incorrect key share.
1004	// See https://www.rfc-editor.org/rfc/rfc8446#page-14.
1005	if version == VersionTLS13 {
1006		deleteTicket()
1007		serverConfig = &Config{
1008			// Use a different curve than the client to force a HelloRetryRequest.
1009			CurvePreferences: []CurveID{CurveP521, CurveP384, CurveP256},
1010			MaxVersion:       version,
1011			Certificates:     testConfig.Certificates,
1012		}
1013		testResumeState("InitialHandshake", false)
1014		testResumeState("WithHelloRetryRequest", true)
1015
1016		// Reset serverConfig back.
1017		serverConfig = &Config{
1018			MaxVersion:   version,
1019			CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1020			Certificates: testConfig.Certificates,
1021		}
1022	}
1023
1024	// Session resumption should work when using client certificates
1025	deleteTicket()
1026	serverConfig.ClientCAs = rootCAs
1027	serverConfig.ClientAuth = RequireAndVerifyClientCert
1028	clientConfig.Certificates = serverConfig.Certificates
1029	testResumeState("InitialHandshake", false)
1030	testResumeState("WithClientCertificates", true)
1031	serverConfig.ClientAuth = NoClientCert
1032
1033	// Tickets should be removed from the session cache on TLS handshake
1034	// failure, and the client should recover from a corrupted PSK
1035	testResumeState("FetchTicketToCorrupt", false)
1036	corruptTicket()
1037	_, _, err = testHandshake(t, clientConfig, serverConfig)
1038	if err == nil {
1039		t.Fatalf("handshake did not fail with a corrupted client secret")
1040	}
1041	testResumeState("AfterHandshakeFailure", false)
1042
1043	clientConfig.ClientSessionCache = nil
1044	testResumeState("WithoutSessionCache", false)
1045
1046	clientConfig.ClientSessionCache = &serializingClientCache{t: t}
1047	testResumeState("BeforeSerializingCache", false)
1048	testResumeState("WithSerializingCache", true)
1049}
1050
// serializingClientCache is a single-entry ClientSessionCache that stores
// sessions in serialized form: a ticket plus the encoded session state. It
// exercises the ClientSessionState serialization round trip on every Get.
// Session keys are ignored, so the most recent Put wins.
type serializingClientCache struct {
	// t receives errors raised while (de)serializing sessions.
	t *testing.T

	// ticket is the session ticket from the last Put, or nil when empty.
	ticket []byte
	// state is the encoded SessionState from the last Put.
	state []byte
}
1056
1057func (c *serializingClientCache) Get(sessionKey string) (session *ClientSessionState, ok bool) {
1058	if c.ticket == nil {
1059		return nil, false
1060	}
1061	state, err := ParseSessionState(c.state)
1062	if err != nil {
1063		c.t.Error(err)
1064		return nil, false
1065	}
1066	cs, err := NewResumptionState(c.ticket, state)
1067	if err != nil {
1068		c.t.Error(err)
1069		return nil, false
1070	}
1071	return cs, true
1072}
1073
1074func (c *serializingClientCache) Put(sessionKey string, cs *ClientSessionState) {
1075	if cs == nil {
1076		c.ticket, c.state = nil, nil
1077		return
1078	}
1079	ticket, state, err := cs.ResumptionState()
1080	if err != nil {
1081		c.t.Error(err)
1082		return
1083	}
1084	stateBytes, err := state.Bytes()
1085	if err != nil {
1086		c.t.Error(err)
1087		return
1088	}
1089	c.ticket, c.state = ticket, stateBytes
1090}
1091
1092func TestLRUClientSessionCache(t *testing.T) {
1093	// Initialize cache of capacity 4.
1094	cache := NewLRUClientSessionCache(4)
1095	cs := make([]ClientSessionState, 6)
1096	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
1097
1098	// Add 4 entries to the cache and look them up.
1099	for i := 0; i < 4; i++ {
1100		cache.Put(keys[i], &cs[i])
1101	}
1102	for i := 0; i < 4; i++ {
1103		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1104			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
1105		}
1106	}
1107
1108	// Add 2 more entries to the cache. First 2 should be evicted.
1109	for i := 4; i < 6; i++ {
1110		cache.Put(keys[i], &cs[i])
1111	}
1112	for i := 0; i < 2; i++ {
1113		if s, ok := cache.Get(keys[i]); ok || s != nil {
1114			t.Fatalf("session cache should have evicted key: %s", keys[i])
1115		}
1116	}
1117
1118	// Touch entry 2. LRU should evict 3 next.
1119	cache.Get(keys[2])
1120	cache.Put(keys[0], &cs[0])
1121	if s, ok := cache.Get(keys[3]); ok || s != nil {
1122		t.Fatalf("session cache should have evicted key 3")
1123	}
1124
1125	// Update entry 0 in place.
1126	cache.Put(keys[0], &cs[3])
1127	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
1128		t.Fatalf("session cache failed update for key 0")
1129	}
1130
1131	// Calling Put with a nil entry deletes the key.
1132	cache.Put(keys[0], nil)
1133	if _, ok := cache.Get(keys[0]); ok {
1134		t.Fatalf("session cache failed to delete key 0")
1135	}
1136
1137	// Delete entry 2. LRU should keep 4 and 5
1138	cache.Put(keys[2], nil)
1139	if _, ok := cache.Get(keys[2]); ok {
1140		t.Fatalf("session cache failed to delete key 4")
1141	}
1142	for i := 4; i < 6; i++ {
1143		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1144			t.Fatalf("session cache should not have deleted key: %s", keys[i])
1145		}
1146	}
1147}
1148
1149func TestKeyLogTLS12(t *testing.T) {
1150	var serverBuf, clientBuf bytes.Buffer
1151
1152	clientConfig := testConfig.Clone()
1153	clientConfig.KeyLogWriter = &clientBuf
1154	clientConfig.MaxVersion = VersionTLS12
1155
1156	serverConfig := testConfig.Clone()
1157	serverConfig.KeyLogWriter = &serverBuf
1158	serverConfig.MaxVersion = VersionTLS12
1159
1160	c, s := localPipe(t)
1161	done := make(chan bool)
1162
1163	go func() {
1164		defer close(done)
1165
1166		if err := Server(s, serverConfig).Handshake(); err != nil {
1167			t.Errorf("server: %s", err)
1168			return
1169		}
1170		s.Close()
1171	}()
1172
1173	if err := Client(c, clientConfig).Handshake(); err != nil {
1174		t.Fatalf("client: %s", err)
1175	}
1176
1177	c.Close()
1178	<-done
1179
1180	checkKeylogLine := func(side, loggedLine string) {
1181		if len(loggedLine) == 0 {
1182			t.Fatalf("%s: no keylog line was produced", side)
1183		}
1184		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
1185			1 /* space */ +
1186			32*2 /* hex client nonce */ +
1187			1 /* space */ +
1188			48*2 /* hex master secret */ +
1189			1 /* new line */
1190		if len(loggedLine) != expectedLen {
1191			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
1192		}
1193		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
1194			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
1195		}
1196	}
1197
1198	checkKeylogLine("client", clientBuf.String())
1199	checkKeylogLine("server", serverBuf.String())
1200}
1201
1202func TestKeyLogTLS13(t *testing.T) {
1203	var serverBuf, clientBuf bytes.Buffer
1204
1205	clientConfig := testConfig.Clone()
1206	clientConfig.KeyLogWriter = &clientBuf
1207
1208	serverConfig := testConfig.Clone()
1209	serverConfig.KeyLogWriter = &serverBuf
1210
1211	c, s := localPipe(t)
1212	done := make(chan bool)
1213
1214	go func() {
1215		defer close(done)
1216
1217		if err := Server(s, serverConfig).Handshake(); err != nil {
1218			t.Errorf("server: %s", err)
1219			return
1220		}
1221		s.Close()
1222	}()
1223
1224	if err := Client(c, clientConfig).Handshake(); err != nil {
1225		t.Fatalf("client: %s", err)
1226	}
1227
1228	c.Close()
1229	<-done
1230
1231	checkKeylogLines := func(side, loggedLines string) {
1232		loggedLines = strings.TrimSpace(loggedLines)
1233		lines := strings.Split(loggedLines, "\n")
1234		if len(lines) != 4 {
1235			t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
1236		}
1237	}
1238
1239	checkKeylogLines("client", clientBuf.String())
1240	checkKeylogLines("server", serverBuf.String())
1241}
1242
1243func TestHandshakeClientALPNMatch(t *testing.T) {
1244	config := testConfig.Clone()
1245	config.NextProtos = []string{"proto2", "proto1"}
1246
1247	test := &clientTest{
1248		name: "ALPN",
1249		// Note that this needs OpenSSL 1.0.2 because that is the first
1250		// version that supports the -alpn flag.
1251		args:   []string{"-alpn", "proto1,proto2"},
1252		config: config,
1253		validate: func(state ConnectionState) error {
1254			// The server's preferences should override the client.
1255			if state.NegotiatedProtocol != "proto1" {
1256				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
1257			}
1258			return nil
1259		},
1260	}
1261	runClientTestTLS12(t, test)
1262	runClientTestTLS13(t, test)
1263}
1264
1265func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
1266	// This checks that the server can't select an application protocol that the
1267	// client didn't offer.
1268
1269	c, s := localPipe(t)
1270	errChan := make(chan error, 1)
1271
1272	go func() {
1273		client := Client(c, &Config{
1274			ServerName:   "foo",
1275			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1276			NextProtos:   []string{"http", "something-else"},
1277		})
1278		errChan <- client.Handshake()
1279	}()
1280
1281	var header [5]byte
1282	if _, err := io.ReadFull(s, header[:]); err != nil {
1283		t.Fatal(err)
1284	}
1285	recordLen := int(header[3])<<8 | int(header[4])
1286
1287	record := make([]byte, recordLen)
1288	if _, err := io.ReadFull(s, record); err != nil {
1289		t.Fatal(err)
1290	}
1291
1292	serverHello := &serverHelloMsg{
1293		vers:         VersionTLS12,
1294		random:       make([]byte, 32),
1295		cipherSuite:  TLS_RSA_WITH_AES_128_GCM_SHA256,
1296		alpnProtocol: "how-about-this",
1297	}
1298	serverHelloBytes := mustMarshal(t, serverHello)
1299
1300	s.Write([]byte{
1301		byte(recordTypeHandshake),
1302		byte(VersionTLS12 >> 8),
1303		byte(VersionTLS12 & 0xff),
1304		byte(len(serverHelloBytes) >> 8),
1305		byte(len(serverHelloBytes)),
1306	})
1307	s.Write(serverHelloBytes)
1308	s.Close()
1309
1310	if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
1311		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1312	}
1313}
1314
// sctsBase64 is a base64-encoded OpenSSL serverinfo blob for extension type 18
// (signed_certificate_timestamp). After a 2-byte extension type, a 2-byte
// extension length and a 2-byte list length, it carries a list of three SCTs,
// each preceded by its own 2-byte length. It was captured with
// `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
1317
1318func TestHandshakClientSCTs(t *testing.T) {
1319	config := testConfig.Clone()
1320
1321	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
1322	if err != nil {
1323		t.Fatal(err)
1324	}
1325
1326	// Note that this needs OpenSSL 1.0.2 because that is the first
1327	// version that supports the -serverinfo flag.
1328	test := &clientTest{
1329		name:       "SCT",
1330		config:     config,
1331		extensions: [][]byte{scts},
1332		validate: func(state ConnectionState) error {
1333			expectedSCTs := [][]byte{
1334				scts[8:125],
1335				scts[127:245],
1336				scts[247:],
1337			}
1338			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
1339				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
1340			}
1341			for i, expected := range expectedSCTs {
1342				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
1343					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
1344				}
1345			}
1346			return nil
1347		},
1348	}
1349	runClientTestTLS12(t, test)
1350
1351	// TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
1352	// supports ServerHello extensions.
1353}
1354
1355func TestRenegotiationRejected(t *testing.T) {
1356	config := testConfig.Clone()
1357	test := &clientTest{
1358		name:                        "RenegotiationRejected",
1359		args:                        []string{"-state"},
1360		config:                      config,
1361		numRenegotiations:           1,
1362		renegotiationExpectedToFail: 1,
1363		checkRenegotiationError: func(renegotiationNum int, err error) error {
1364			if err == nil {
1365				return errors.New("expected error from renegotiation but got nil")
1366			}
1367			if !strings.Contains(err.Error(), "no renegotiation") {
1368				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1369			}
1370			return nil
1371		},
1372	}
1373	runClientTestTLS12(t, test)
1374}
1375
1376func TestRenegotiateOnce(t *testing.T) {
1377	config := testConfig.Clone()
1378	config.Renegotiation = RenegotiateOnceAsClient
1379
1380	test := &clientTest{
1381		name:              "RenegotiateOnce",
1382		args:              []string{"-state"},
1383		config:            config,
1384		numRenegotiations: 1,
1385	}
1386
1387	runClientTestTLS12(t, test)
1388}
1389
1390func TestRenegotiateTwice(t *testing.T) {
1391	config := testConfig.Clone()
1392	config.Renegotiation = RenegotiateFreelyAsClient
1393
1394	test := &clientTest{
1395		name:              "RenegotiateTwice",
1396		args:              []string{"-state"},
1397		config:            config,
1398		numRenegotiations: 2,
1399	}
1400
1401	runClientTestTLS12(t, test)
1402}
1403
1404func TestRenegotiateTwiceRejected(t *testing.T) {
1405	config := testConfig.Clone()
1406	config.Renegotiation = RenegotiateOnceAsClient
1407
1408	test := &clientTest{
1409		name:                        "RenegotiateTwiceRejected",
1410		args:                        []string{"-state"},
1411		config:                      config,
1412		numRenegotiations:           2,
1413		renegotiationExpectedToFail: 2,
1414		checkRenegotiationError: func(renegotiationNum int, err error) error {
1415			if renegotiationNum == 1 {
1416				return err
1417			}
1418
1419			if err == nil {
1420				return errors.New("expected error from renegotiation but got nil")
1421			}
1422			if !strings.Contains(err.Error(), "no renegotiation") {
1423				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1424			}
1425			return nil
1426		},
1427	}
1428
1429	runClientTestTLS12(t, test)
1430}
1431
1432func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
1433	test := &clientTest{
1434		name:   "ExportKeyingMaterial",
1435		config: testConfig.Clone(),
1436		validate: func(state ConnectionState) error {
1437			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
1438				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
1439			} else if len(km) != 42 {
1440				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
1441			}
1442			return nil
1443		},
1444	}
1445	runClientTestTLS10(t, test)
1446	runClientTestTLS12(t, test)
1447	runClientTestTLS13(t, test)
1448}
1449
// hostnameInSNITests maps a Config.ServerName (in) to the server_name value
// expected in the ClientHello (out). An empty out means no name may be sent.
// That covers IP literals, and the empty string itself.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque strings are sent unchanged.
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS hostnames lose any trailing dot.
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// Literal IPv4 address.
	{in: "1.2.3.4", out: ""},

	// Literal IPv6 addresses.
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""}, // with zone identifier
	{in: "[::1]", out: ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{in: "[::1%lo0]", out: ""},
}
1471
1472func TestHostnameInSNI(t *testing.T) {
1473	for _, tt := range hostnameInSNITests {
1474		c, s := localPipe(t)
1475
1476		go func(host string) {
1477			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
1478		}(tt.in)
1479
1480		var header [5]byte
1481		if _, err := io.ReadFull(s, header[:]); err != nil {
1482			t.Fatal(err)
1483		}
1484		recordLen := int(header[3])<<8 | int(header[4])
1485
1486		record := make([]byte, recordLen)
1487		if _, err := io.ReadFull(s, record[:]); err != nil {
1488			t.Fatal(err)
1489		}
1490
1491		c.Close()
1492		s.Close()
1493
1494		var m clientHelloMsg
1495		if !m.unmarshal(record) {
1496			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
1497			continue
1498		}
1499		if tt.in != tt.out && m.serverName == tt.in {
1500			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
1501		}
1502		if m.serverName != tt.out {
1503			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
1504		}
1505	}
1506}
1507
1508func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
1509	// This checks that the server can't select a cipher suite that the
1510	// client didn't offer. See #13174.
1511
1512	c, s := localPipe(t)
1513	errChan := make(chan error, 1)
1514
1515	go func() {
1516		client := Client(c, &Config{
1517			ServerName:   "foo",
1518			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1519		})
1520		errChan <- client.Handshake()
1521	}()
1522
1523	var header [5]byte
1524	if _, err := io.ReadFull(s, header[:]); err != nil {
1525		t.Fatal(err)
1526	}
1527	recordLen := int(header[3])<<8 | int(header[4])
1528
1529	record := make([]byte, recordLen)
1530	if _, err := io.ReadFull(s, record); err != nil {
1531		t.Fatal(err)
1532	}
1533
1534	// Create a ServerHello that selects a different cipher suite than the
1535	// sole one that the client offered.
1536	serverHello := &serverHelloMsg{
1537		vers:        VersionTLS12,
1538		random:      make([]byte, 32),
1539		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
1540	}
1541	serverHelloBytes := mustMarshal(t, serverHello)
1542
1543	s.Write([]byte{
1544		byte(recordTypeHandshake),
1545		byte(VersionTLS12 >> 8),
1546		byte(VersionTLS12 & 0xff),
1547		byte(len(serverHelloBytes) >> 8),
1548		byte(len(serverHelloBytes)),
1549	})
1550	s.Write(serverHelloBytes)
1551	s.Close()
1552
1553	if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
1554		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1555	}
1556}
1557
1558func TestVerifyConnection(t *testing.T) {
1559	t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
1560	t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
1561}
1562
1563func testVerifyConnection(t *testing.T, version uint16) {
1564	checkFields := func(c ConnectionState, called *int, errorType string) error {
1565		if c.Version != version {
1566			return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
1567		}
1568		if c.HandshakeComplete {
1569			return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
1570		}
1571		if c.ServerName != "example.golang" {
1572			return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
1573		}
1574		if c.NegotiatedProtocol != "protocol1" {
1575			return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
1576		}
1577		if c.CipherSuite == 0 {
1578			return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
1579		}
1580		wantDidResume := false
1581		if *called == 2 { // if this is the second time, then it should be a resumption
1582			wantDidResume = true
1583		}
1584		if c.DidResume != wantDidResume {
1585			return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
1586		}
1587		return nil
1588	}
1589
1590	tests := []struct {
1591		name            string
1592		configureServer func(*Config, *int)
1593		configureClient func(*Config, *int)
1594	}{
1595		{
1596			name: "RequireAndVerifyClientCert",
1597			configureServer: func(config *Config, called *int) {
1598				config.ClientAuth = RequireAndVerifyClientCert
1599				config.VerifyConnection = func(c ConnectionState) error {
1600					*called++
1601					if l := len(c.PeerCertificates); l != 1 {
1602						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1603					}
1604					if len(c.VerifiedChains) == 0 {
1605						return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
1606					}
1607					return checkFields(c, called, "server")
1608				}
1609			},
1610			configureClient: func(config *Config, called *int) {
1611				config.VerifyConnection = func(c ConnectionState) error {
1612					*called++
1613					if l := len(c.PeerCertificates); l != 1 {
1614						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1615					}
1616					if len(c.VerifiedChains) == 0 {
1617						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1618					}
1619					if c.DidResume {
1620						return nil
1621						// The SCTs and OCSP Response are dropped on resumption.
1622						// See http://golang.org/issue/39075.
1623					}
1624					if len(c.OCSPResponse) == 0 {
1625						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1626					}
1627					if len(c.SignedCertificateTimestamps) == 0 {
1628						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1629					}
1630					return checkFields(c, called, "client")
1631				}
1632			},
1633		},
1634		{
1635			name: "InsecureSkipVerify",
1636			configureServer: func(config *Config, called *int) {
1637				config.ClientAuth = RequireAnyClientCert
1638				config.InsecureSkipVerify = true
1639				config.VerifyConnection = func(c ConnectionState) error {
1640					*called++
1641					if l := len(c.PeerCertificates); l != 1 {
1642						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1643					}
1644					if c.VerifiedChains != nil {
1645						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1646					}
1647					return checkFields(c, called, "server")
1648				}
1649			},
1650			configureClient: func(config *Config, called *int) {
1651				config.InsecureSkipVerify = true
1652				config.VerifyConnection = func(c ConnectionState) error {
1653					*called++
1654					if l := len(c.PeerCertificates); l != 1 {
1655						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1656					}
1657					if c.VerifiedChains != nil {
1658						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1659					}
1660					if c.DidResume {
1661						return nil
1662						// The SCTs and OCSP Response are dropped on resumption.
1663						// See http://golang.org/issue/39075.
1664					}
1665					if len(c.OCSPResponse) == 0 {
1666						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1667					}
1668					if len(c.SignedCertificateTimestamps) == 0 {
1669						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1670					}
1671					return checkFields(c, called, "client")
1672				}
1673			},
1674		},
1675		{
1676			name: "NoClientCert",
1677			configureServer: func(config *Config, called *int) {
1678				config.ClientAuth = NoClientCert
1679				config.VerifyConnection = func(c ConnectionState) error {
1680					*called++
1681					return checkFields(c, called, "server")
1682				}
1683			},
1684			configureClient: func(config *Config, called *int) {
1685				config.VerifyConnection = func(c ConnectionState) error {
1686					*called++
1687					return checkFields(c, called, "client")
1688				}
1689			},
1690		},
1691		{
1692			name: "RequestClientCert",
1693			configureServer: func(config *Config, called *int) {
1694				config.ClientAuth = RequestClientCert
1695				config.VerifyConnection = func(c ConnectionState) error {
1696					*called++
1697					return checkFields(c, called, "server")
1698				}
1699			},
1700			configureClient: func(config *Config, called *int) {
1701				config.Certificates = nil // clear the client cert
1702				config.VerifyConnection = func(c ConnectionState) error {
1703					*called++
1704					if l := len(c.PeerCertificates); l != 1 {
1705						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1706					}
1707					if len(c.VerifiedChains) == 0 {
1708						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1709					}
1710					if c.DidResume {
1711						return nil
1712						// The SCTs and OCSP Response are dropped on resumption.
1713						// See http://golang.org/issue/39075.
1714					}
1715					if len(c.OCSPResponse) == 0 {
1716						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1717					}
1718					if len(c.SignedCertificateTimestamps) == 0 {
1719						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1720					}
1721					return checkFields(c, called, "client")
1722				}
1723			},
1724		},
1725	}
1726	for _, test := range tests {
1727		issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1728		if err != nil {
1729			panic(err)
1730		}
1731		rootCAs := x509.NewCertPool()
1732		rootCAs.AddCert(issuer)
1733
1734		var serverCalled, clientCalled int
1735
1736		serverConfig := &Config{
1737			MaxVersion:   version,
1738			Certificates: []Certificate{testConfig.Certificates[0]},
1739			ClientCAs:    rootCAs,
1740			NextProtos:   []string{"protocol1"},
1741		}
1742		serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1743		serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
1744		test.configureServer(serverConfig, &serverCalled)
1745
1746		clientConfig := &Config{
1747			MaxVersion:         version,
1748			ClientSessionCache: NewLRUClientSessionCache(32),
1749			RootCAs:            rootCAs,
1750			ServerName:         "example.golang",
1751			Certificates:       []Certificate{testConfig.Certificates[0]},
1752			NextProtos:         []string{"protocol1"},
1753		}
1754		test.configureClient(clientConfig, &clientCalled)
1755
1756		testHandshakeState := func(name string, didResume bool) {
1757			_, hs, err := testHandshake(t, clientConfig, serverConfig)
1758			if err != nil {
1759				t.Fatalf("%s: handshake failed: %s", name, err)
1760			}
1761			if hs.DidResume != didResume {
1762				t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
1763			}
1764			wantCalled := 1
1765			if didResume {
1766				wantCalled = 2 // resumption would mean this is the second time it was called in this test
1767			}
1768			if clientCalled != wantCalled {
1769				t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
1770			}
1771			if serverCalled != wantCalled {
1772				t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
1773			}
1774		}
1775		testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
1776		testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
1777	}
1778}
1779
1780func TestVerifyPeerCertificate(t *testing.T) {
1781	t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
1782	t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
1783}
1784
1785func testVerifyPeerCertificate(t *testing.T, version uint16) {
1786	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1787	if err != nil {
1788		panic(err)
1789	}
1790
1791	rootCAs := x509.NewCertPool()
1792	rootCAs.AddCert(issuer)
1793
1794	now := func() time.Time { return time.Unix(1476984729, 0) }
1795
1796	sentinelErr := errors.New("TestVerifyPeerCertificate")
1797
1798	verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1799		if l := len(rawCerts); l != 1 {
1800			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1801		}
1802		if len(validatedChains) == 0 {
1803			return errors.New("got len(validatedChains) = 0, wanted non-zero")
1804		}
1805		*called = true
1806		return nil
1807	}
1808	verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
1809		if l := len(c.PeerCertificates); l != 1 {
1810			return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
1811		}
1812		if len(c.VerifiedChains) == 0 {
1813			return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
1814		}
1815		if isClient && len(c.OCSPResponse) == 0 {
1816			return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
1817		}
1818		*called = true
1819		return nil
1820	}
1821
1822	tests := []struct {
1823		configureServer func(*Config, *bool)
1824		configureClient func(*Config, *bool)
1825		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
1826	}{
1827		{
1828			configureServer: func(config *Config, called *bool) {
1829				config.InsecureSkipVerify = false
1830				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1831					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1832				}
1833			},
1834			configureClient: func(config *Config, called *bool) {
1835				config.InsecureSkipVerify = false
1836				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1837					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1838				}
1839			},
1840			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1841				if clientErr != nil {
1842					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1843				}
1844				if serverErr != nil {
1845					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1846				}
1847				if !clientCalled {
1848					t.Errorf("test[%d]: client did not call callback", testNo)
1849				}
1850				if !serverCalled {
1851					t.Errorf("test[%d]: server did not call callback", testNo)
1852				}
1853			},
1854		},
1855		{
1856			configureServer: func(config *Config, called *bool) {
1857				config.InsecureSkipVerify = false
1858				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1859					return sentinelErr
1860				}
1861			},
1862			configureClient: func(config *Config, called *bool) {
1863				config.VerifyPeerCertificate = nil
1864			},
1865			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1866				if serverErr != sentinelErr {
1867					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1868				}
1869			},
1870		},
1871		{
1872			configureServer: func(config *Config, called *bool) {
1873				config.InsecureSkipVerify = false
1874			},
1875			configureClient: func(config *Config, called *bool) {
1876				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1877					return sentinelErr
1878				}
1879			},
1880			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1881				if clientErr != sentinelErr {
1882					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1883				}
1884			},
1885		},
1886		{
1887			configureServer: func(config *Config, called *bool) {
1888				config.InsecureSkipVerify = false
1889			},
1890			configureClient: func(config *Config, called *bool) {
1891				config.InsecureSkipVerify = true
1892				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1893					if l := len(rawCerts); l != 1 {
1894						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1895					}
1896					// With InsecureSkipVerify set, this
1897					// callback should still be called but
1898					// validatedChains must be empty.
1899					if l := len(validatedChains); l != 0 {
1900						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
1901					}
1902					*called = true
1903					return nil
1904				}
1905			},
1906			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1907				if clientErr != nil {
1908					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1909				}
1910				if serverErr != nil {
1911					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1912				}
1913				if !clientCalled {
1914					t.Errorf("test[%d]: client did not call callback", testNo)
1915				}
1916			},
1917		},
1918		{
1919			configureServer: func(config *Config, called *bool) {
1920				config.InsecureSkipVerify = false
1921				config.VerifyConnection = func(c ConnectionState) error {
1922					return verifyConnectionCallback(called, false, c)
1923				}
1924			},
1925			configureClient: func(config *Config, called *bool) {
1926				config.InsecureSkipVerify = false
1927				config.VerifyConnection = func(c ConnectionState) error {
1928					return verifyConnectionCallback(called, true, c)
1929				}
1930			},
1931			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1932				if clientErr != nil {
1933					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1934				}
1935				if serverErr != nil {
1936					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1937				}
1938				if !clientCalled {
1939					t.Errorf("test[%d]: client did not call callback", testNo)
1940				}
1941				if !serverCalled {
1942					t.Errorf("test[%d]: server did not call callback", testNo)
1943				}
1944			},
1945		},
1946		{
1947			configureServer: func(config *Config, called *bool) {
1948				config.InsecureSkipVerify = false
1949				config.VerifyConnection = func(c ConnectionState) error {
1950					return sentinelErr
1951				}
1952			},
1953			configureClient: func(config *Config, called *bool) {
1954				config.InsecureSkipVerify = false
1955				config.VerifyConnection = nil
1956			},
1957			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1958				if serverErr != sentinelErr {
1959					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1960				}
1961			},
1962		},
1963		{
1964			configureServer: func(config *Config, called *bool) {
1965				config.InsecureSkipVerify = false
1966				config.VerifyConnection = nil
1967			},
1968			configureClient: func(config *Config, called *bool) {
1969				config.InsecureSkipVerify = false
1970				config.VerifyConnection = func(c ConnectionState) error {
1971					return sentinelErr
1972				}
1973			},
1974			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1975				if clientErr != sentinelErr {
1976					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1977				}
1978			},
1979		},
1980		{
1981			configureServer: func(config *Config, called *bool) {
1982				config.InsecureSkipVerify = false
1983				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1984					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1985				}
1986				config.VerifyConnection = func(c ConnectionState) error {
1987					return sentinelErr
1988				}
1989			},
1990			configureClient: func(config *Config, called *bool) {
1991				config.InsecureSkipVerify = false
1992				config.VerifyPeerCertificate = nil
1993				config.VerifyConnection = nil
1994			},
1995			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1996				if serverErr != sentinelErr {
1997					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1998				}
1999				if !serverCalled {
2000					t.Errorf("test[%d]: server did not call callback", testNo)
2001				}
2002			},
2003		},
2004		{
2005			configureServer: func(config *Config, called *bool) {
2006				config.InsecureSkipVerify = false
2007				config.VerifyPeerCertificate = nil
2008				config.VerifyConnection = nil
2009			},
2010			configureClient: func(config *Config, called *bool) {
2011				config.InsecureSkipVerify = false
2012				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
2013					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
2014				}
2015				config.VerifyConnection = func(c ConnectionState) error {
2016					return sentinelErr
2017				}
2018			},
2019			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2020				if clientErr != sentinelErr {
2021					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
2022				}
2023				if !clientCalled {
2024					t.Errorf("test[%d]: client did not call callback", testNo)
2025				}
2026			},
2027		},
2028	}
2029
2030	for i, test := range tests {
2031		c, s := localPipe(t)
2032		done := make(chan error)
2033
2034		var clientCalled, serverCalled bool
2035
2036		go func() {
2037			config := testConfig.Clone()
2038			config.ServerName = "example.golang"
2039			config.ClientAuth = RequireAndVerifyClientCert
2040			config.ClientCAs = rootCAs
2041			config.Time = now
2042			config.MaxVersion = version
2043			config.Certificates = make([]Certificate, 1)
2044			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
2045			config.Certificates[0].PrivateKey = testRSAPrivateKey
2046			config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
2047			config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
2048			test.configureServer(config, &serverCalled)
2049
2050			err = Server(s, config).Handshake()
2051			s.Close()
2052			done <- err
2053		}()
2054
2055		config := testConfig.Clone()
2056		config.ServerName = "example.golang"
2057		config.RootCAs = rootCAs
2058		config.Time = now
2059		config.MaxVersion = version
2060		test.configureClient(config, &clientCalled)
2061		clientErr := Client(c, config).Handshake()
2062		c.Close()
2063		serverErr := <-done
2064
2065		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
2066	}
2067}
2068
// brokenConn is a net.Conn wrapper whose Write calls start failing with
// brokenConnErr once breakAfter writes have been forwarded to the wrapped
// connection. All other methods go straight to the embedded net.Conn.
type brokenConn struct {
	net.Conn

	// breakAfter is how many Write calls are forwarded before every
	// later call fails.
	breakAfter int

	// numWrites counts the Write calls forwarded so far.
	numWrites int
}

// brokenConnErr is what brokenConn.Write returns once its budget is spent.
var brokenConnErr = errors.New("too many writes to brokenConn")

// Write forwards p to the wrapped connection while the write budget lasts.
// After that it returns (0, brokenConnErr) without touching the wrapped conn.
func (bc *brokenConn) Write(p []byte) (int, error) {
	if bc.numWrites < bc.breakAfter {
		bc.numWrites++
		return bc.Conn.Write(p)
	}
	return 0, brokenConnErr
}
2093
2094func TestFailedWrite(t *testing.T) {
2095	// Test that a write error during the handshake is returned.
2096	for _, breakAfter := range []int{0, 1} {
2097		c, s := localPipe(t)
2098		done := make(chan bool)
2099
2100		go func() {
2101			Server(s, testConfig).Handshake()
2102			s.Close()
2103			done <- true
2104		}()
2105
2106		brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
2107		err := Client(brokenC, testConfig).Handshake()
2108		if err != brokenConnErr {
2109			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
2110		}
2111		brokenC.Close()
2112
2113		<-done
2114	}
2115}
2116
// writeCountingConn is a net.Conn wrapper that tallies how many times Write
// is called. The data is passed through to the embedded connection unchanged.
type writeCountingConn struct {
	net.Conn

	// numWrites is the number of Write calls made so far. Calls that the
	// wrapped connection fails are counted too.
	numWrites int
}

// Write bumps the call counter and then delegates to the wrapped connection.
func (w *writeCountingConn) Write(p []byte) (n int, err error) {
	w.numWrites++
	n, err = w.Conn.Write(p)
	return n, err
}
2129
2130func TestBuffering(t *testing.T) {
2131	t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
2132	t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
2133}
2134
2135func testBuffering(t *testing.T, version uint16) {
2136	c, s := localPipe(t)
2137	done := make(chan bool)
2138
2139	clientWCC := &writeCountingConn{Conn: c}
2140	serverWCC := &writeCountingConn{Conn: s}
2141
2142	go func() {
2143		config := testConfig.Clone()
2144		config.MaxVersion = version
2145		Server(serverWCC, config).Handshake()
2146		serverWCC.Close()
2147		done <- true
2148	}()
2149
2150	err := Client(clientWCC, testConfig).Handshake()
2151	if err != nil {
2152		t.Fatal(err)
2153	}
2154	clientWCC.Close()
2155	<-done
2156
2157	var expectedClient, expectedServer int
2158	if version == VersionTLS13 {
2159		expectedClient = 2
2160		expectedServer = 1
2161	} else {
2162		expectedClient = 2
2163		expectedServer = 2
2164	}
2165
2166	if n := clientWCC.numWrites; n != expectedClient {
2167		t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
2168	}
2169
2170	if n := serverWCC.numWrites; n != expectedServer {
2171		t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
2172	}
2173}
2174
2175func TestAlertFlushing(t *testing.T) {
2176	c, s := localPipe(t)
2177	done := make(chan bool)
2178
2179	clientWCC := &writeCountingConn{Conn: c}
2180	serverWCC := &writeCountingConn{Conn: s}
2181
2182	serverConfig := testConfig.Clone()
2183
2184	// Cause a signature-time error
2185	brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
2186	brokenKey.D = big.NewInt(42)
2187	serverConfig.Certificates = []Certificate{{
2188		Certificate: [][]byte{testRSACertificate},
2189		PrivateKey:  &brokenKey,
2190	}}
2191
2192	go func() {
2193		Server(serverWCC, serverConfig).Handshake()
2194		serverWCC.Close()
2195		done <- true
2196	}()
2197
2198	err := Client(clientWCC, testConfig).Handshake()
2199	if err == nil {
2200		t.Fatal("client unexpectedly returned no error")
2201	}
2202
2203	const expectedError = "remote error: tls: internal error"
2204	if e := err.Error(); !strings.Contains(e, expectedError) {
2205		t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
2206	}
2207	clientWCC.Close()
2208	<-done
2209
2210	if n := serverWCC.numWrites; n != 1 {
2211		t.Errorf("expected server handshake to complete with one write, but saw %d", n)
2212	}
2213}
2214
// TestHandshakeRace never calls the client's Handshake explicitly. Instead,
// a Write and a Read are started from separate goroutines, so either one may
// end up driving the handshake. The release order alternates with the
// iteration index. Each iteration must finish: the server echoes one byte
// and the reader receives it. Run under -race, this also checks for data
// races.
func TestHandshakeRace(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	t.Parallel()
	// This test races a Read and Write to try and complete a handshake in
	// order to provide some evidence that there are no races or deadlocks
	// in the handshake locking.
	for i := 0; i < 32; i++ {
		c, s := localPipe(t)

		// Echo server: complete the handshake, read exactly one byte, and
		// write it back. Failures panic because this is not the test
		// goroutine.
		go func() {
			server := Server(s, testConfig)
			if err := server.Handshake(); err != nil {
				panic(err)
			}

			var request [1]byte
			if n, err := server.Read(request[:]); err != nil || n != 1 {
				panic(err)
			}

			server.Write(request[:])
			server.Close()
		}()

		// startWrite and startRead are unbuffered, so each send below
		// returns only after the corresponding goroutine has been
		// released. readDone is buffered so the reader never blocks on
		// signalling completion.
		startWrite := make(chan struct{})
		startRead := make(chan struct{})
		readDone := make(chan struct{}, 1)

		client := Client(c, testConfig)
		// Writer: once released, send one byte. Its result is not checked.
		go func() {
			<-startWrite
			var request [1]byte
			client.Write(request[:])
		}()

		// Reader: once released, wait for the one-byte echo, then close
		// the client side and report completion.
		go func() {
			<-startRead
			var reply [1]byte
			if _, err := io.ReadFull(client, reply[:]); err != nil {
				panic(err)
			}
			c.Close()
			readDone <- struct{}{}
		}()

		// Odd iterations release the writer first and even iterations the
		// reader first, so both orderings get exercised.
		if i&1 == 1 {
			startWrite <- struct{}{}
			startRead <- struct{}{}
		} else {
			startRead <- struct{}{}
			startWrite <- struct{}{}
		}
		<-readDone
	}
}
2272
2273var getClientCertificateTests = []struct {
2274	setup               func(*Config, *Config)
2275	expectedClientError string
2276	verify              func(*testing.T, int, *ConnectionState)
2277}{
2278	{
2279		func(clientConfig, serverConfig *Config) {
2280			// Returning a Certificate with no certificate data
2281			// should result in an empty message being sent to the
2282			// server.
2283			serverConfig.ClientCAs = nil
2284			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2285				if len(cri.SignatureSchemes) == 0 {
2286					panic("empty SignatureSchemes")
2287				}
2288				if len(cri.AcceptableCAs) != 0 {
2289					panic("AcceptableCAs should have been empty")
2290				}
2291				return new(Certificate), nil
2292			}
2293		},
2294		"",
2295		func(t *testing.T, testNum int, cs *ConnectionState) {
2296			if l := len(cs.PeerCertificates); l != 0 {
2297				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2298			}
2299		},
2300	},
2301	{
2302		func(clientConfig, serverConfig *Config) {
2303			// With TLS 1.1, the SignatureSchemes should be
2304			// synthesised from the supported certificate types.
2305			clientConfig.MaxVersion = VersionTLS11
2306			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2307				if len(cri.SignatureSchemes) == 0 {
2308					panic("empty SignatureSchemes")
2309				}
2310				return new(Certificate), nil
2311			}
2312		},
2313		"",
2314		func(t *testing.T, testNum int, cs *ConnectionState) {
2315			if l := len(cs.PeerCertificates); l != 0 {
2316				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2317			}
2318		},
2319	},
2320	{
2321		func(clientConfig, serverConfig *Config) {
2322			// Returning an error should abort the handshake with
2323			// that error.
2324			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2325				return nil, errors.New("GetClientCertificate")
2326			}
2327		},
2328		"GetClientCertificate",
2329		func(t *testing.T, testNum int, cs *ConnectionState) {
2330		},
2331	},
2332	{
2333		func(clientConfig, serverConfig *Config) {
2334			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2335				if len(cri.AcceptableCAs) == 0 {
2336					panic("empty AcceptableCAs")
2337				}
2338				cert := &Certificate{
2339					Certificate: [][]byte{testRSACertificate},
2340					PrivateKey:  testRSAPrivateKey,
2341				}
2342				return cert, nil
2343			}
2344		},
2345		"",
2346		func(t *testing.T, testNum int, cs *ConnectionState) {
2347			if len(cs.VerifiedChains) == 0 {
2348				t.Errorf("#%d: expected some verified chains, but found none", testNum)
2349			}
2350		},
2351	},
2352}
2353
2354func TestGetClientCertificate(t *testing.T) {
2355	t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
2356	t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
2357}
2358
2359func testGetClientCertificate(t *testing.T, version uint16) {
2360	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2361	if err != nil {
2362		panic(err)
2363	}
2364
2365	for i, test := range getClientCertificateTests {
2366		serverConfig := testConfig.Clone()
2367		serverConfig.ClientAuth = VerifyClientCertIfGiven
2368		serverConfig.RootCAs = x509.NewCertPool()
2369		serverConfig.RootCAs.AddCert(issuer)
2370		serverConfig.ClientCAs = serverConfig.RootCAs
2371		serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
2372		serverConfig.MaxVersion = version
2373
2374		clientConfig := testConfig.Clone()
2375		clientConfig.MaxVersion = version
2376
2377		test.setup(clientConfig, serverConfig)
2378
2379		type serverResult struct {
2380			cs  ConnectionState
2381			err error
2382		}
2383
2384		c, s := localPipe(t)
2385		done := make(chan serverResult)
2386
2387		go func() {
2388			defer s.Close()
2389			server := Server(s, serverConfig)
2390			err := server.Handshake()
2391
2392			var cs ConnectionState
2393			if err == nil {
2394				cs = server.ConnectionState()
2395			}
2396			done <- serverResult{cs, err}
2397		}()
2398
2399		clientErr := Client(c, clientConfig).Handshake()
2400		c.Close()
2401
2402		result := <-done
2403
2404		if clientErr != nil {
2405			if len(test.expectedClientError) == 0 {
2406				t.Errorf("#%d: client error: %v", i, clientErr)
2407			} else if got := clientErr.Error(); got != test.expectedClientError {
2408				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
2409			} else {
2410				test.verify(t, i, &result.cs)
2411			}
2412		} else if len(test.expectedClientError) > 0 {
2413			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
2414		} else if err := result.err; err != nil {
2415			t.Errorf("#%d: server error: %v", i, err)
2416		} else {
2417			test.verify(t, i, &result.cs)
2418		}
2419	}
2420}
2421
// TestRSAPSSKeyError guards against misuse of RSASSA-PSS certificates. It
// checks that a certificate whose key has the RSASSA-PSS OID is either
// rejected by crypto/x509 or parsed into something other than an
// *rsa.PublicKey.
func TestRSAPSSKeyError(t *testing.T) {
	// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
	// public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
	// the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
	// parse, or that they don't carry *rsa.PublicKey keys.
	const rsaPSSCertPEM = `
-----BEGIN CERTIFICATE-----
MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
RwBA9Xk1KBNF
-----END CERTIFICATE-----`
	block, _ := pem.Decode([]byte(rsaPSSCertPEM))
	if block == nil {
		t.Fatal("Failed to decode certificate")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		// Refusing to parse the certificate at all is an acceptable outcome.
		return
	}
	if _, isPKCS1 := cert.PublicKey.(*rsa.PublicKey); isPKCS1 {
		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
	}
}
2460
2461func TestCloseClientConnectionOnIdleServer(t *testing.T) {
2462	clientConn, serverConn := localPipe(t)
2463	client := Client(clientConn, testConfig.Clone())
2464	go func() {
2465		var b [1]byte
2466		serverConn.Read(b[:])
2467		client.Close()
2468	}()
2469	client.SetWriteDeadline(time.Now().Add(time.Minute))
2470	err := client.Handshake()
2471	if err != nil {
2472		if err, ok := err.(net.Error); ok && err.Timeout() {
2473			t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
2474		}
2475	} else {
2476		t.Errorf("Error expected, but no error returned")
2477	}
2478}
2479
2480func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
2481	defer func() { testingOnlyForceDowngradeCanary = false }()
2482	testingOnlyForceDowngradeCanary = true
2483
2484	clientConfig := testConfig.Clone()
2485	clientConfig.MaxVersion = clientVersion
2486	serverConfig := testConfig.Clone()
2487	serverConfig.MaxVersion = serverVersion
2488	_, _, err := testHandshake(t, clientConfig, serverConfig)
2489	return err
2490}
2491
2492func TestDowngradeCanary(t *testing.T) {
2493	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
2494		t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
2495	}
2496	if testing.Short() {
2497		t.Skip("skipping the rest of the checks in short mode")
2498	}
2499	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
2500		t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
2501	}
2502	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
2503		t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
2504	}
2505	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
2506		t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
2507	}
2508	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
2509		t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
2510	}
2511	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
2512		t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
2513	}
2514	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
2515		t.Errorf("client didn't ignore expected TLS 1.2 canary")
2516	}
2517	if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
2518		t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
2519	}
2520	if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
2521		t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
2522	}
2523}
2524
2525func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
2526	t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
2527	t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
2528}
2529
2530func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
2531	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2532	if err != nil {
2533		t.Fatalf("failed to parse test issuer")
2534	}
2535	roots := x509.NewCertPool()
2536	roots.AddCert(issuer)
2537	clientConfig := &Config{
2538		MaxVersion:         ver,
2539		ClientSessionCache: NewLRUClientSessionCache(32),
2540		ServerName:         "example.golang",
2541		RootCAs:            roots,
2542	}
2543	serverConfig := testConfig.Clone()
2544	serverConfig.MaxVersion = ver
2545	serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
2546	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
2547
2548	_, ccs, err := testHandshake(t, clientConfig, serverConfig)
2549	if err != nil {
2550		t.Fatalf("handshake failed: %s", err)
2551	}
2552	// after a new session we expect to see OCSPResponse and
2553	// SignedCertificateTimestamps populated as usual
2554	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2555		t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
2556			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2557	}
2558	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2559		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
2560			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2561	}
2562
2563	// if the server doesn't send any SCTs, repopulate the old SCTs
2564	oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
2565	serverConfig.Certificates[0].SignedCertificateTimestamps = nil
2566	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
2567	if err != nil {
2568		t.Fatalf("handshake failed: %s", err)
2569	}
2570	if !ccs.DidResume {
2571		t.Fatalf("expected session to be resumed")
2572	}
2573	// after a resumed session we also expect to see OCSPResponse
2574	// and SignedCertificateTimestamps populated
2575	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2576		t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
2577			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2578	}
2579	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
2580		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2581			oldSCTs, ccs.SignedCertificateTimestamps)
2582	}
2583
2584	//  Only test overriding the SCTs for TLS 1.2, since in 1.3
2585	// the server won't send the message containing them
2586	if ver == VersionTLS13 {
2587		return
2588	}
2589
2590	// if the server changes the SCTs it sends, they should override the saved SCTs
2591	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
2592	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
2593	if err != nil {
2594		t.Fatalf("handshake failed: %s", err)
2595	}
2596	if !ccs.DidResume {
2597		t.Fatalf("expected session to be resumed")
2598	}
2599	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2600		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2601			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2602	}
2603}
2604
2605// TestClientHandshakeContextCancellation tests that canceling
2606// the context given to the client side conn.HandshakeContext
2607// interrupts the in-progress handshake.
2608func TestClientHandshakeContextCancellation(t *testing.T) {
2609	c, s := localPipe(t)
2610	ctx, cancel := context.WithCancel(context.Background())
2611	unblockServer := make(chan struct{})
2612	defer close(unblockServer)
2613	go func() {
2614		cancel()
2615		<-unblockServer
2616		_ = s.Close()
2617	}()
2618	cli := Client(c, testConfig)
2619	// Initiates client side handshake, which will block until the client hello is read
2620	// by the server, unless the cancellation works.
2621	err := cli.HandshakeContext(ctx)
2622	if err == nil {
2623		t.Fatal("Client handshake did not error when the context was canceled")
2624	}
2625	if err != context.Canceled {
2626		t.Errorf("Unexpected client handshake error: %v", err)
2627	}
2628	if runtime.GOARCH == "wasm" {
2629		t.Skip("conn.Close does not error as expected when called multiple times on WASM")
2630	}
2631	err = cli.Close()
2632	if err == nil {
2633		t.Error("Client connection was not closed when the context was canceled")
2634	}
2635}
2636
2637// TestTLS13OnlyClientHelloCipherSuite tests that when a client states that
2638// it only supports TLS 1.3, it correctly advertises only TLS 1.3 ciphers.
2639func TestTLS13OnlyClientHelloCipherSuite(t *testing.T) {
2640	tls13Tests := []struct {
2641		name    string
2642		ciphers []uint16
2643	}{
2644		{
2645			name:    "nil",
2646			ciphers: nil,
2647		},
2648		{
2649			name:    "empty",
2650			ciphers: []uint16{},
2651		},
2652		{
2653			name:    "some TLS 1.2 cipher",
2654			ciphers: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
2655		},
2656		{
2657			name:    "some TLS 1.3 cipher",
2658			ciphers: []uint16{TLS_AES_128_GCM_SHA256},
2659		},
2660		{
2661			name:    "some TLS 1.2 and 1.3 ciphers",
2662			ciphers: []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_AES_256_GCM_SHA384},
2663		},
2664	}
2665	for _, tt := range tls13Tests {
2666		tt := tt
2667		t.Run(tt.name, func(t *testing.T) {
2668			t.Parallel()
2669			testTLS13OnlyClientHelloCipherSuite(t, tt.ciphers)
2670		})
2671	}
2672}
2673
2674func testTLS13OnlyClientHelloCipherSuite(t *testing.T, ciphers []uint16) {
2675	serverConfig := &Config{
2676		Certificates: testConfig.Certificates,
2677		GetConfigForClient: func(chi *ClientHelloInfo) (*Config, error) {
2678			if len(chi.CipherSuites) != len(defaultCipherSuitesTLS13NoAES) {
2679				t.Errorf("only TLS 1.3 suites should be advertised, got=%x", chi.CipherSuites)
2680			} else {
2681				for i := range defaultCipherSuitesTLS13NoAES {
2682					if want, got := defaultCipherSuitesTLS13NoAES[i], chi.CipherSuites[i]; want != got {
2683						t.Errorf("cipher at index %d does not match, want=%x, got=%x", i, want, got)
2684					}
2685				}
2686			}
2687			return nil, nil
2688		},
2689	}
2690	clientConfig := &Config{
2691		MinVersion:         VersionTLS13, // client only supports TLS 1.3
2692		CipherSuites:       ciphers,
2693		InsecureSkipVerify: true,
2694	}
2695	if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
2696		t.Fatalf("handshake failed: %s", err)
2697	}
2698}
2699
// discardConn is a net.Conn wrapper that swallows every Write. Each Write
// reports the full length as written and never fails. All other methods fall
// through to the embedded net.Conn.
type discardConn struct {
	net.Conn
}

// Write drops p and reports success.
func (*discardConn) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}
2708
// largeRSAKeyCertPEM is a PEM certificate containing an 8193-bit RSA key, one
// bit over the 8192-bit limit that TestHandshakeRSATooBig expects both
// handshake sides to enforce.
const largeRSAKeyCertPEM = `-----BEGIN CERTIFICATE-----
MIIInjCCBIWgAwIBAgIBAjANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDEwd0ZXN0
aW5nMB4XDTIzMDYwNzIxMjMzNloXDTIzMDYwNzIzMjMzNlowEjEQMA4GA1UEAxMH
dGVzdGluZzCCBCIwDQYJKoZIhvcNAQEBBQADggQPADCCBAoCggQBAWdHsf6Rh2Ca
n2SQwn4t4OQrOjbLLdGE1pM6TBKKrHUFy62uEL8atNjlcfXIsa4aEu3xNGiqxqur
ZectlkZbm0FkaaQ1Wr9oikDY3KfjuaXdPdO/XC/h8AKNxlDOylyXwUSK/CuYb+1j
gy8yF5QFvVfwW/xwTlHmhUeSkVSQPosfQ6yXNNsmMzkd+ZPWLrfq4R+wiNtwYGu0
WSBcI/M9o8/vrNLnIppoiBJJ13j9CR1ToEAzOFh9wwRWLY10oZhoh1ONN1KQURx4
qedzvvP2DSjZbUccdvl2rBGvZpzfOiFdm1FCnxB0c72Cqx+GTHXBFf8bsa7KHky9
sNO1GUanbq17WoDNgwbY6H51bfShqv0CErxatwWox3we4EcAmFHPVTCYL1oWVMGo
a3Eth91NZj+b/nGhF9lhHKGzXSv9brmLLkfvM1jA6XhNhA7BQ5Vz67lj2j3XfXdh
t/BU5pBXbL4Ut4mIhT1YnKXAjX2/LF5RHQTE8Vwkx5JAEKZyUEGOReD/B+7GOrLp
HduMT9vZAc5aR2k9I8qq1zBAzsL69lyQNAPaDYd1BIAjUety9gAYaSQffCgAgpRO
Gt+DYvxS+7AT/yEd5h74MU2AH7KrAkbXOtlwupiGwhMVTstncDJWXMJqbBhyHPF8
3UmZH0hbL4PYmzSj9LDWQQXI2tv6vrCpfts3Cqhqxz9vRpgY7t1Wu6l/r+KxYYz3
1pcGpPvRmPh0DJm7cPTiXqPnZcPt+ulSaSdlxmd19OnvG5awp0fXhxryZVwuiT8G
VDkhyARrxYrdjlINsZJZbQjO0t8ketXAELJOnbFXXzeCOosyOHkLwsqOO96AVJA8
45ZVL5m95ClGy0RSrjVIkXsxTAMVG6SPAqKwk6vmTdRGuSPS4rhgckPVDHmccmuq
dfnT2YkX+wB2/M3oCgU+s30fAHGkbGZ0pCdNbFYFZLiH0iiMbTDl/0L/z7IdK0nH
GLHVE7apPraKC6xl6rPWsD2iSfrmtIPQa0+rqbIVvKP5JdfJ8J4alI+OxFw/znQe
V0/Rez0j22Fe119LZFFSXhRv+ZSvcq20xDwh00mzcumPWpYuCVPozA18yIhC9tNn
ALHndz0tDseIdy9vC71jQWy9iwri3ueN0DekMMF8JGzI1Z6BAFzgyAx3DkHtwHg7
B7qD0jPG5hJ5+yt323fYgJsuEAYoZ8/jzZ01pkX8bt+UsVN0DGnSGsI2ktnIIk3J
l+8krjmUy6EaW79nITwoOqaeHOIp8m3UkjEcoKOYrzHRKqRy+A09rY+m/cAQaafW
4xp0Zv7qZPLwnu0jsqB4jD8Ll9yPB02ndsoV6U5PeHzTkVhPml19jKUAwFfs7TJg
kXy+/xFhYVUCAwEAATANBgkqhkiG9w0BAQsFAAOCBAIAAQnZY77pMNeypfpba2WK
aDasT7dk2JqP0eukJCVPTN24Zca+xJNPdzuBATm/8SdZK9lddIbjSnWRsKvTnO2r
/rYdlPf3jM5uuJtb8+Uwwe1s+gszelGS9G/lzzq+ehWicRIq2PFcs8o3iQMfENiv
qILJ+xjcrvms5ZPDNahWkfRx3KCg8Q+/at2n5p7XYjMPYiLKHnDC+RE2b1qT20IZ
FhuK/fTWLmKbfYFNNga6GC4qcaZJ7x0pbm4SDTYp0tkhzcHzwKhidfNB5J2vNz6l
Ur6wiYwamFTLqcOwWo7rdvI+sSn05WQBv0QZlzFX+OAu0l7WQ7yU+noOxBhjvHds
14+r9qcQZg2q9kG+evopYZqYXRUNNlZKo9MRBXhfrISulFAc5lRFQIXMXnglvAu+
Ipz2gomEAOcOPNNVldhKAU94GAMJd/KfN0ZP7gX3YvPzuYU6XDhag5RTohXLm18w
5AF+ES3DOQ6ixu3DTf0D+6qrDuK+prdX8ivcdTQVNOQ+MIZeGSc6NWWOTaMGJ3lg
aZIxJUGdo6E7GBGiC1YTjgFKFbHzek1LRTh/LX3vbSudxwaG0HQxwsU9T4DWiMqa
Fkf2KteLEUA6HrR+0XlAZrhwoqAmrJ+8lCFX3V0gE9lpENfVHlFXDGyx10DpTB28
DdjnY3F7EPWNzwf9P3oNT69CKW3Bk6VVr3ROOJtDxVu1ioWo3TaXltQ0VOnap2Pu
sa5wfrpfwBDuAS9JCDg4ttNp2nW3F7tgXC6xPqw5pvGwUppEw9XNrqV8TZrxduuv
rQ3NyZ7KSzIpmFlD3UwV/fGfz3UQmHS6Ng1evrUID9DjfYNfRqSGIGjDfxGtYD+j
Z1gLJZuhjJpNtwBkKRtlNtrCWCJK2hidK/foxwD7kwAPo2I9FjpltxCRywZUs07X
KwXTfBR9v6ij1LV6K58hFS+8ezZyZ05CeVBFkMQdclTOSfuPxlMkQOtjp8QWDj+F
j/MYziT5KBkHvcbrjdRtUJIAi4N7zCsPZtjik918AK1WBNRVqPbrgq/XSEXMfuvs
6JbfK0B76vdBDRtJFC1JsvnIrGbUztxXzyQwFLaR/AjVJqpVlysLWzPKWVX6/+SJ
u1NQOl2E8P6ycyBsuGnO89p0S4F8cMRcI2X1XQsZ7/q0NBrOMaEp5T3SrWo9GiQ3
o2SBdbs3Y6MBPBtTu977Z/0RO63J3M5i2tjUiDfrFy7+VRLKr7qQ7JibohyB8QaR
9tedgjn2f+of7PnP/PEl1cCphUZeHM7QKUMPT8dbqwmKtlYY43EHXcvNOT5IBk3X
9lwJoZk/B2i+ZMRNSP34ztAwtxmasPt6RAWGQpWCn9qmttAHAnMfDqe7F7jVR6rS
u58=
-----END CERTIFICATE-----`
2759
2760func TestHandshakeRSATooBig(t *testing.T) {
2761	testCert, _ := pem.Decode([]byte(largeRSAKeyCertPEM))
2762
2763	c := &Conn{conn: &discardConn{}, config: testConfig.Clone()}
2764
2765	expectedErr := "tls: server sent certificate containing RSA key larger than 8192 bits"
2766	err := c.verifyServerCertificate([][]byte{testCert.Bytes})
2767	if err == nil || err.Error() != expectedErr {
2768		t.Errorf("Conn.verifyServerCertificate unexpected error: want %q, got %q", expectedErr, err)
2769	}
2770
2771	expectedErr = "tls: client sent certificate containing RSA key larger than 8192 bits"
2772	err = c.processCertsFromClient(Certificate{Certificate: [][]byte{testCert.Bytes}})
2773	if err == nil || err.Error() != expectedErr {
2774		t.Errorf("Conn.processCertsFromClient unexpected error: want %q, got %q", expectedErr, err)
2775	}
2776}
2777
2778func TestTLS13ECHRejectionCallbacks(t *testing.T) {
2779	k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
2780	if err != nil {
2781		t.Fatal(err)
2782	}
2783	tmpl := &x509.Certificate{
2784		SerialNumber: big.NewInt(1),
2785		Subject:      pkix.Name{CommonName: "test"},
2786		DNSNames:     []string{"example.golang"},
2787		NotBefore:    testConfig.Time().Add(-time.Hour),
2788		NotAfter:     testConfig.Time().Add(time.Hour),
2789	}
2790	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
2791	if err != nil {
2792		t.Fatal(err)
2793	}
2794	cert, err := x509.ParseCertificate(certDER)
2795	if err != nil {
2796		t.Fatal(err)
2797	}
2798
2799	clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
2800	serverConfig.Certificates = []Certificate{
2801		{
2802			Certificate: [][]byte{certDER},
2803			PrivateKey:  k,
2804		},
2805	}
2806	serverConfig.MinVersion = VersionTLS13
2807	clientConfig.RootCAs = x509.NewCertPool()
2808	clientConfig.RootCAs.AddCert(cert)
2809	clientConfig.MinVersion = VersionTLS13
2810	clientConfig.EncryptedClientHelloConfigList, _ = hex.DecodeString("0041fe0d003d0100200020204bed0a11fc0dde595a9b78d966b0011128eb83f65d3c91c1cc5ac786cd246f000400010001ff0e6578616d706c652e676f6c616e670000")
2811	clientConfig.ServerName = "example.golang"
2812
2813	for _, tc := range []struct {
2814		name        string
2815		expectedErr string
2816
2817		verifyConnection                    func(ConnectionState) error
2818		verifyPeerCertificate               func([][]byte, [][]*x509.Certificate) error
2819		encryptedClientHelloRejectionVerify func(ConnectionState) error
2820	}{
2821		{
2822			name:        "no callbacks",
2823			expectedErr: "tls: server rejected ECH",
2824		},
2825		{
2826			name: "EncryptedClientHelloRejectionVerify, no err",
2827			encryptedClientHelloRejectionVerify: func(ConnectionState) error {
2828				return nil
2829			},
2830			expectedErr: "tls: server rejected ECH",
2831		},
2832		{
2833			name: "EncryptedClientHelloRejectionVerify, err",
2834			encryptedClientHelloRejectionVerify: func(ConnectionState) error {
2835				return errors.New("callback err")
2836			},
2837			// testHandshake returns the server side error, so we just need to
2838			// check alertBadCertificate was sent
2839			expectedErr: "callback err",
2840		},
2841		{
2842			name: "VerifyConnection, err",
2843			verifyConnection: func(ConnectionState) error {
2844				return errors.New("callback err")
2845			},
2846			expectedErr: "tls: server rejected ECH",
2847		},
2848		{
2849			name: "VerifyPeerCertificate, err",
2850			verifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
2851				return errors.New("callback err")
2852			},
2853			expectedErr: "tls: server rejected ECH",
2854		},
2855	} {
2856		t.Run(tc.name, func(t *testing.T) {
2857			c, s := localPipe(t)
2858			done := make(chan error)
2859
2860			go func() {
2861				serverErr := Server(s, serverConfig).Handshake()
2862				s.Close()
2863				done <- serverErr
2864			}()
2865
2866			cConfig := clientConfig.Clone()
2867			cConfig.VerifyConnection = tc.verifyConnection
2868			cConfig.VerifyPeerCertificate = tc.verifyPeerCertificate
2869			cConfig.EncryptedClientHelloRejectionVerify = tc.encryptedClientHelloRejectionVerify
2870
2871			clientErr := Client(c, cConfig).Handshake()
2872			c.Close()
2873
2874			if tc.expectedErr == "" && clientErr != nil {
2875				t.Fatalf("unexpected err: %s", clientErr)
2876			} else if clientErr != nil && tc.expectedErr != clientErr.Error() {
2877				t.Fatalf("unexpected err: got %q, want %q", clientErr, tc.expectedErr)
2878			}
2879		})
2880	}
2881}
2882
2883func TestECHTLS12Server(t *testing.T) {
2884	clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
2885
2886	serverConfig.MaxVersion = VersionTLS12
2887	clientConfig.MinVersion = 0
2888
2889	clientConfig.EncryptedClientHelloConfigList, _ = hex.DecodeString("0041fe0d003d0100200020204bed0a11fc0dde595a9b78d966b0011128eb83f65d3c91c1cc5ac786cd246f000400010001ff0e6578616d706c652e676f6c616e670000")
2890
2891	expectedErr := "server: tls: client offered only unsupported versions: [304]\nclient: remote error: tls: protocol version not supported"
2892	_, _, err := testHandshake(t, clientConfig, serverConfig)
2893	if err == nil || err.Error() != expectedErr {
2894		t.Fatalf("unexpected handshake error: got %q, want %q", err, expectedErr)
2895	}
2896}
2897