1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"context"
10	"crypto"
11	"crypto/ecdsa"
12	"crypto/elliptic"
13	"crypto/rand"
14	"crypto/x509"
15	"crypto/x509/pkix"
16	"encoding/asn1"
17	"encoding/json"
18	"encoding/pem"
19	"errors"
20	"fmt"
21	"internal/testenv"
22	"io"
23	"math"
24	"math/big"
25	"net"
26	"os"
27	"reflect"
28	"slices"
29	"strings"
30	"testing"
31	"time"
32)
33
// rsaCertPEM is a PEM "CERTIFICATE" block. keyPairTests pairs it with
// rsaKeyPEM and with keyPEM as the certificate half of the RSA test cases.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB0zCCAX2gAwIBAgIJAI/M7BYjwB+uMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQwHhcNMTIwOTEyMjE1MjAyWhcNMTUwOTEyMjE1MjAyWjBF
MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50
ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANLJ
hPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wok/4xIA+ui35/MmNa
rtNuC+BdZ1tMuVCPFZcCAwEAAaNQME4wHQYDVR0OBBYEFJvKs8RfJaXTH08W+SGv
zQyKn0H8MB8GA1UdIwQYMBaAFJvKs8RfJaXTH08W+SGvzQyKn0H8MAwGA1UdEwQF
MAMBAf8wDQYJKoZIhvcNAQEFBQADQQBJlffJHybjDGxRMqaRmDhX0+6v02TUKZsW
r5QuVbpQhH6u+0UgcW0jp9QwpxoPTLTWGXEWBBBurxFwiCBhkQ+V
-----END CERTIFICATE-----
`
47
// rsaKeyPEM is the key that keyPairTests pairs with rsaCertPEM for the "RSA"
// case. The literal is labelled "RSA TESTING KEY" and is passed through
// testingKey (defined elsewhere; presumably it rewrites the label to a real
// "PRIVATE KEY" label — confirm against its definition).
var rsaKeyPEM = testingKey(`-----BEGIN RSA TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END RSA TESTING KEY-----
`)
58
// keyPEM is the same as rsaKeyPEM, but declares itself as just
// "PRIVATE KEY", not "RSA PRIVATE KEY".  https://golang.org/issue/4477
//
// The body is identical to rsaKeyPEM; only the PEM label differs. It is used
// by the "RSA-untyped" entry in keyPairTests.
var keyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIBOwIBAAJBANLJhPHhITqQbPklG3ibCVxwGMRfp/v4XqhfdQHdcVfHap6NQ5Wo
k/4xIA+ui35/MmNartNuC+BdZ1tMuVCPFZcCAwEAAQJAEJ2N+zsR0Xn8/Q6twa4G
6OB1M1WO+k+ztnX/1SvNeWu8D6GImtupLTYgjZcHufykj09jiHmjHx8u8ZZB/o1N
MQIhAPW+eyZo7ay3lMz1V01WVjNKK9QSn1MJlb06h/LuYv9FAiEA25WPedKgVyCW
SmUwbPw8fnTcpqDWE3yTO3vKcebqMSsCIBF3UmVue8YU3jybC3NxuXq3wNm34R8T
xVLHwDXh/6NJAiEAl2oHGGLz64BuAfjKrqwz7qMYr9HCLIe/YsoWq/olzScCIQDi
D2lWusoe2/nEqfDVVWGWlyJ7yOmqaVm/iNUN9B2N2g==
-----END TESTING KEY-----
`)
71
// ecdsaCertPEM is a PEM "CERTIFICATE" block. keyPairTests pairs it with
// ecdsaKeyPEM as the certificate half of the "ECDSA" test case.
var ecdsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIB/jCCAWICCQDscdUxw16XFDAJBgcqhkjOPQQBMEUxCzAJBgNVBAYTAkFVMRMw
EQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0
eSBMdGQwHhcNMTIxMTE0MTI0MDQ4WhcNMTUxMTE0MTI0MDQ4WjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMIGbMBAGByqGSM49AgEGBSuBBAAjA4GGAAQBY9+my9OoeSUR
lDQdV/x8LsOuLilthhiS1Tz4aGDHIPwC1mlvnf7fg5lecYpMCrLLhauAc1UJXcgl
01xoLuzgtAEAgv2P/jgytzRSpUYvgLBt1UA0leLYBy6mQQbrNEuqT3INapKIcUv8
XxYP0xMEUksLPq6Ca+CRSqTtrd/23uTnapkwCQYHKoZIzj0EAQOBigAwgYYCQXJo
A7Sl2nLVf+4Iu/tAX/IF4MavARKC4PPHK3zfuGfPR3oCCcsAoz3kAzOeijvd0iXb
H5jBImIxPL4WxQNiBTexAkF8D1EtpYuWdlVQ80/h/f4pBcGiXPqX5h2PQSQY7hP1
+jwM1FGS4fREIOvlBYr/SzzQRtwrvrzGYxDEDbsC0ZGRnA==
-----END CERTIFICATE-----
`
86
// ecdsaKeyPEM is the key that keyPairTests pairs with ecdsaCertPEM. The
// literal holds two PEM blocks: an "EC PARAMETERS" block followed by an
// "EC TESTING KEY" block. It is passed through testingKey (defined elsewhere;
// presumably it rewrites the "TESTING KEY" label — confirm against its
// definition).
var ecdsaKeyPEM = testingKey(`-----BEGIN EC PARAMETERS-----
BgUrgQQAIw==
-----END EC PARAMETERS-----
-----BEGIN EC TESTING KEY-----
MIHcAgEBBEIBrsoKp0oqcv6/JovJJDoDVSGWdirrkgCWxrprGlzB9o0X8fV675X0
NwuBenXFfeZvVcwluO7/Q9wkYoPd/t3jGImgBwYFK4EEACOhgYkDgYYABAFj36bL
06h5JRGUNB1X/Hwuw64uKW2GGJLVPPhoYMcg/ALWaW+d/t+DmV5xikwKssuFq4Bz
VQldyCXTXGgu7OC0AQCC/Y/+ODK3NFKlRi+AsG3VQDSV4tgHLqZBBus0S6pPcg1q
kohxS/xfFg/TEwRSSws+roJr4JFKpO2t3/be5OdqmQ==
-----END EC TESTING KEY-----
`)
98
// keyPairTests lists the certificate/key pairs that TestX509KeyPair expects
// X509KeyPair to load successfully.
var keyPairTests = []struct {
	algo string // label used in failure messages
	cert string // PEM-encoded certificate
	key  string // PEM-encoded private key matching cert
}{
	{"ECDSA", ecdsaCertPEM, ecdsaKeyPEM},
	{"RSA", rsaCertPEM, rsaKeyPEM},
	{"RSA-untyped", rsaCertPEM, keyPEM}, // golang.org/issue/4477
}
108
109func TestX509KeyPair(t *testing.T) {
110	t.Parallel()
111	var pem []byte
112	for _, test := range keyPairTests {
113		pem = []byte(test.cert + test.key)
114		if _, err := X509KeyPair(pem, pem); err != nil {
115			t.Errorf("Failed to load %s cert followed by %s key: %s", test.algo, test.algo, err)
116		}
117		pem = []byte(test.key + test.cert)
118		if _, err := X509KeyPair(pem, pem); err != nil {
119			t.Errorf("Failed to load %s key followed by %s cert: %s", test.algo, test.algo, err)
120		}
121	}
122}
123
124func TestX509KeyPairErrors(t *testing.T) {
125	_, err := X509KeyPair([]byte(rsaKeyPEM), []byte(rsaCertPEM))
126	if err == nil {
127		t.Fatalf("X509KeyPair didn't return an error when arguments were switched")
128	}
129	if subStr := "been switched"; !strings.Contains(err.Error(), subStr) {
130		t.Fatalf("Expected %q in the error when switching arguments to X509KeyPair, but the error was %q", subStr, err)
131	}
132
133	_, err = X509KeyPair([]byte(rsaCertPEM), []byte(rsaCertPEM))
134	if err == nil {
135		t.Fatalf("X509KeyPair didn't return an error when both arguments were certificates")
136	}
137	if subStr := "certificate"; !strings.Contains(err.Error(), subStr) {
138		t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were certificates, but the error was %q", subStr, err)
139	}
140
141	const nonsensePEM = `
142-----BEGIN NONSENSE-----
143Zm9vZm9vZm9v
144-----END NONSENSE-----
145`
146
147	_, err = X509KeyPair([]byte(nonsensePEM), []byte(nonsensePEM))
148	if err == nil {
149		t.Fatalf("X509KeyPair didn't return an error when both arguments were nonsense")
150	}
151	if subStr := "NONSENSE"; !strings.Contains(err.Error(), subStr) {
152		t.Fatalf("Expected %q in the error when both arguments to X509KeyPair were nonsense, but the error was %q", subStr, err)
153	}
154}
155
156func TestX509MixedKeyPair(t *testing.T) {
157	if _, err := X509KeyPair([]byte(rsaCertPEM), []byte(ecdsaKeyPEM)); err == nil {
158		t.Error("Load of RSA certificate succeeded with ECDSA private key")
159	}
160	if _, err := X509KeyPair([]byte(ecdsaCertPEM), []byte(rsaKeyPEM)); err == nil {
161		t.Error("Load of ECDSA certificate succeeded with RSA private key")
162	}
163}
164
// newLocalListener returns a TCP listener bound to an ephemeral port on the
// IPv4 loopback address. If that fails it tries the IPv6 loopback address
// instead, and it fails the test if that attempt fails too.
func newLocalListener(t testing.TB) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	// IPv4 loopback is unavailable; fall back to IPv6.
	if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
		t.Fatal(err)
	}
	return ln
}
175
// TestDialTimeout checks that DialWithDialer reports a timeout error when the
// Dialer's Timeout expires during the TLS handshake. The local listener only
// accepts connections and never reads or writes on them, so a handshake
// against it cannot complete. The timeout starts at 100µs and is doubled on
// every attempt in which the dial returned before the listener had accepted
// anything.
func TestDialTimeout(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 100 * time.Microsecond
	for !t.Failed() {
		// acceptc carries every accepted connection to the test goroutine
		// and is closed once Accept fails (i.e. after listener.Close).
		acceptc := make(chan net.Conn)
		listener := newLocalListener(t)
		go func() {
			for {
				conn, err := listener.Accept()
				if err != nil {
					close(acceptc)
					return
				}
				acceptc <- conn
			}
		}()

		addr := listener.Addr().String()
		dialer := &net.Dialer{
			Timeout: timeout,
		}
		if conn, err := DialWithDialer(dialer, "tcp", addr, nil); err == nil {
			conn.Close()
			t.Errorf("DialWithTimeout unexpectedly completed successfully")
		} else if !isTimeoutError(err) {
			t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err)
		}

		// Closing the listener makes the accept goroutine close acceptc,
		// which terminates the receives below.
		listener.Close()

		// We're looking for a timeout during the handshake, so check that the
		// Listener actually accepted the connection to initiate it. (If the server
		// takes too long to accept the connection, we might cancel before the
		// underlying net.Conn is ever dialed — without ever attempting a
		// handshake.)
		lconn, ok := <-acceptc
		if ok {
			// The Listener accepted a connection, so assume that it was from our
			// Dial: we triggered the timeout at the point where we wanted it!
			t.Logf("Listener accepted a connection from %s", lconn.RemoteAddr())
			lconn.Close()
		}
		// Close any spurious extra connections from the listener. (This is
		// possible if there are, for example, stray Dial calls from other tests.)
		for extraConn := range acceptc {
			t.Logf("spurious extra connection from %s", extraConn.RemoteAddr())
			extraConn.Close()
		}
		if ok {
			break
		}

		t.Logf("with timeout %v, DialWithDialer returned before listener accepted any connections; retrying", timeout)
		timeout *= 2
	}
}
235
236func TestDeadlineOnWrite(t *testing.T) {
237	if testing.Short() {
238		t.Skip("skipping in short mode")
239	}
240
241	ln := newLocalListener(t)
242	defer ln.Close()
243
244	srvCh := make(chan *Conn, 1)
245
246	go func() {
247		sconn, err := ln.Accept()
248		if err != nil {
249			srvCh <- nil
250			return
251		}
252		srv := Server(sconn, testConfig.Clone())
253		if err := srv.Handshake(); err != nil {
254			srvCh <- nil
255			return
256		}
257		srvCh <- srv
258	}()
259
260	clientConfig := testConfig.Clone()
261	clientConfig.MaxVersion = VersionTLS12
262	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
263	if err != nil {
264		t.Fatal(err)
265	}
266	defer conn.Close()
267
268	srv := <-srvCh
269	if srv == nil {
270		t.Error(err)
271	}
272
273	// Make sure the client/server is setup correctly and is able to do a typical Write/Read
274	buf := make([]byte, 6)
275	if _, err := srv.Write([]byte("foobar")); err != nil {
276		t.Errorf("Write err: %v", err)
277	}
278	if n, err := conn.Read(buf); n != 6 || err != nil || string(buf) != "foobar" {
279		t.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
280	}
281
282	// Set a deadline which should cause Write to timeout
283	if err = srv.SetDeadline(time.Now()); err != nil {
284		t.Fatalf("SetDeadline(time.Now()) err: %v", err)
285	}
286	if _, err = srv.Write([]byte("should fail")); err == nil {
287		t.Fatal("Write should have timed out")
288	}
289
290	// Clear deadline and make sure it still times out
291	if err = srv.SetDeadline(time.Time{}); err != nil {
292		t.Fatalf("SetDeadline(time.Time{}) err: %v", err)
293	}
294	if _, err = srv.Write([]byte("This connection is permanently broken")); err == nil {
295		t.Fatal("Write which previously failed should still time out")
296	}
297
298	// Verify the error
299	if ne := err.(net.Error); ne.Temporary() != false {
300		t.Error("Write timed out but incorrectly classified the error as Temporary")
301	}
302	if !isTimeoutError(err) {
303		t.Error("Write timed out but did not classify the error as a Timeout")
304	}
305}
306
// readerFunc lets an ordinary function be used wherever a value with a Read
// method is required.
type readerFunc func([]byte) (int, error)

// Read invokes f on b and returns its results unchanged.
func (f readerFunc) Read(b []byte) (int, error) {
	n, err := f(b)
	return n, err
}
310
311// TestDialer tests that tls.Dialer.DialContext can abort in the middle of a handshake.
312// (The other cases are all handled by the existing dial tests in this package, which
313// all also flow through the same code shared code paths)
314func TestDialer(t *testing.T) {
315	ln := newLocalListener(t)
316	defer ln.Close()
317
318	unblockServer := make(chan struct{}) // close-only
319	defer close(unblockServer)
320	go func() {
321		conn, err := ln.Accept()
322		if err != nil {
323			return
324		}
325		defer conn.Close()
326		<-unblockServer
327	}()
328
329	ctx, cancel := context.WithCancel(context.Background())
330	d := Dialer{Config: &Config{
331		Rand: readerFunc(func(b []byte) (n int, err error) {
332			// By the time crypto/tls wants randomness, that means it has a TCP
333			// connection, so we're past the Dialer's dial and now blocked
334			// in a handshake. Cancel our context and see if we get unstuck.
335			// (Our TCP listener above never reads or writes, so the Handshake
336			// would otherwise be stuck forever)
337			cancel()
338			return len(b), nil
339		}),
340		ServerName: "foo",
341	}}
342	_, err := d.DialContext(ctx, "tcp", ln.Addr().String())
343	if err != context.Canceled {
344		t.Errorf("err = %v; want context.Canceled", err)
345	}
346}
347
// isTimeoutError reports whether err itself implements net.Error and its
// Timeout method returns true. Errors that do not implement net.Error
// (including nil) yield false.
func isTimeoutError(err error) bool {
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}
354
355// tests that Conn.Read returns (non-zero, io.EOF) instead of
356// (non-zero, nil) when a Close (alertCloseNotify) is sitting right
357// behind the application data in the buffer.
358func TestConnReadNonzeroAndEOF(t *testing.T) {
359	// This test is racy: it assumes that after a write to a
360	// localhost TCP connection, the peer TCP connection can
361	// immediately read it. Because it's racy, we skip this test
362	// in short mode, and then retry it several times with an
363	// increasing sleep in between our final write (via srv.Close
364	// below) and the following read.
365	if testing.Short() {
366		t.Skip("skipping in short mode")
367	}
368	var err error
369	for delay := time.Millisecond; delay <= 64*time.Millisecond; delay *= 2 {
370		if err = testConnReadNonzeroAndEOF(t, delay); err == nil {
371			return
372		}
373	}
374	t.Error(err)
375}
376
377func testConnReadNonzeroAndEOF(t *testing.T, delay time.Duration) error {
378	ln := newLocalListener(t)
379	defer ln.Close()
380
381	srvCh := make(chan *Conn, 1)
382	var serr error
383	go func() {
384		sconn, err := ln.Accept()
385		if err != nil {
386			serr = err
387			srvCh <- nil
388			return
389		}
390		serverConfig := testConfig.Clone()
391		srv := Server(sconn, serverConfig)
392		if err := srv.Handshake(); err != nil {
393			serr = fmt.Errorf("handshake: %v", err)
394			srvCh <- nil
395			return
396		}
397		srvCh <- srv
398	}()
399
400	clientConfig := testConfig.Clone()
401	// In TLS 1.3, alerts are encrypted and disguised as application data, so
402	// the opportunistic peek won't work.
403	clientConfig.MaxVersion = VersionTLS12
404	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
405	if err != nil {
406		t.Fatal(err)
407	}
408	defer conn.Close()
409
410	srv := <-srvCh
411	if srv == nil {
412		return serr
413	}
414
415	buf := make([]byte, 6)
416
417	srv.Write([]byte("foobar"))
418	n, err := conn.Read(buf)
419	if n != 6 || err != nil || string(buf) != "foobar" {
420		return fmt.Errorf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf)
421	}
422
423	srv.Write([]byte("abcdef"))
424	srv.Close()
425	time.Sleep(delay)
426	n, err = conn.Read(buf)
427	if n != 6 || string(buf) != "abcdef" {
428		return fmt.Errorf("Read = %d, buf= %q; want 6, abcdef", n, buf)
429	}
430	if err != io.EOF {
431		return fmt.Errorf("Second Read error = %v; want io.EOF", err)
432	}
433	return nil
434}
435
// TestTLSUniqueMatches checks that client and server agree on the TLSUnique
// channel binding value, and that the value is neither nil nor twelve zero
// bytes, both for a first connection and for a second connection that is
// expected to resume the first session.
func TestTLSUniqueMatches(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// serverTLSUniques carries the server's TLSUnique for each connection.
	// parentDone is closed when the test function returns; childDone is
	// closed when the server goroutine returns. Each side selects on the
	// other's channel so that neither blocks forever if the other bails out.
	serverTLSUniques := make(chan []byte)
	parentDone := make(chan struct{})
	childDone := make(chan struct{})
	defer close(parentDone)
	go func() {
		defer close(childDone)
		// Serve exactly two connections: the initial one and the resumed one.
		for i := 0; i < 2; i++ {
			sconn, err := ln.Accept()
			if err != nil {
				t.Error(err)
				return
			}
			serverConfig := testConfig.Clone()
			serverConfig.MaxVersion = VersionTLS12 // TLSUnique is not defined in TLS 1.3
			srv := Server(sconn, serverConfig)
			if err := srv.Handshake(); err != nil {
				t.Error(err)
				return
			}
			select {
			case <-parentDone:
				return
			case serverTLSUniques <- srv.ConnectionState().TLSUnique:
			}
		}
	}()

	clientConfig := testConfig.Clone()
	// A session cache is needed for the second Dial to resume the session.
	clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}

	var serverTLSUniquesValue []byte
	select {
	case <-childDone:
		// The server goroutine exited early; it has already reported why.
		return
	case serverTLSUniquesValue = <-serverTLSUniques:
	}

	if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
		t.Error("client and server channel bindings differ")
	}
	if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
		t.Error("tls-unique is empty or zero")
	}
	conn.Close()

	// Second connection with the same client config: expected to resume.
	conn, err = Dial("tcp", ln.Addr().String(), clientConfig)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	if !conn.ConnectionState().DidResume {
		t.Error("second session did not use resumption")
	}

	select {
	case <-childDone:
		return
	case serverTLSUniquesValue = <-serverTLSUniques:
	}

	if !bytes.Equal(conn.ConnectionState().TLSUnique, serverTLSUniquesValue) {
		t.Error("client and server channel bindings differ when session resumption is used")
	}
	if serverTLSUniquesValue == nil || bytes.Equal(serverTLSUniquesValue, make([]byte, 12)) {
		t.Error("resumption tls-unique is empty or zero")
	}
}
511
512func TestVerifyHostname(t *testing.T) {
513	testenv.MustHaveExternalNetwork(t)
514
515	c, err := Dial("tcp", "www.google.com:https", nil)
516	if err != nil {
517		t.Fatal(err)
518	}
519	if err := c.VerifyHostname("www.google.com"); err != nil {
520		t.Fatalf("verify www.google.com: %v", err)
521	}
522	if err := c.VerifyHostname("www.yahoo.com"); err == nil {
523		t.Fatalf("verify www.yahoo.com succeeded")
524	}
525
526	c, err = Dial("tcp", "www.google.com:https", &Config{InsecureSkipVerify: true})
527	if err != nil {
528		t.Fatal(err)
529	}
530	if err := c.VerifyHostname("www.google.com"); err == nil {
531		t.Fatalf("verify www.google.com succeeded with InsecureSkipVerify=true")
532	}
533}
534
// TestConnCloseBreakingWrite checks that Conn.Close returns while another
// goroutine is blocked inside Write on the underlying connection, that the
// blocked Write then surfaces the underlying error, and that a second Close
// returns net.ErrClosed.
func TestConnCloseBreakingWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	srvCh := make(chan *Conn, 1)
	// serr and sconn are written by the server goroutine before it sends on
	// srvCh and are only read after the receive from srvCh below.
	var serr error
	var sconn net.Conn
	go func() {
		var err error
		sconn, err = ln.Accept()
		if err != nil {
			serr = err
			srvCh <- nil
			return
		}
		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			serr = fmt.Errorf("handshake: %v", err)
			srvCh <- nil
			return
		}
		srvCh <- srv
	}()

	cconn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cconn.Close()

	// Wrap the raw connection so that Write and Close can be replaced after
	// the handshake has completed over the real implementations.
	conn := &changeImplConn{
		Conn: cconn,
	}

	clientConfig := testConfig.Clone()
	tconn := Client(conn, clientConfig)
	if err := tconn.Handshake(); err != nil {
		t.Fatal(err)
	}

	srv := <-srvCh
	if srv == nil {
		t.Fatal(serr)
	}
	defer sconn.Close()

	// From here on, closing the underlying connection only signals
	// connClosed instead of closing the socket.
	connClosed := make(chan struct{})
	conn.closeFunc = func() error {
		close(connClosed)
		return nil
	}

	// The replacement Write announces that it has been entered and then
	// blocks until the underlying Close has been called.
	inWrite := make(chan bool, 1)
	var errConnClosed = errors.New("conn closed for test")
	conn.writeFunc = func(p []byte) (n int, err error) {
		inWrite <- true
		<-connClosed
		return 0, errConnClosed
	}

	closeReturned := make(chan bool, 1)
	go func() {
		<-inWrite
		tconn.Close() // test that this doesn't block forever.
		closeReturned <- true
	}()

	_, err = tconn.Write([]byte("foo"))
	if err != errConnClosed {
		t.Errorf("Write error = %v; want errConnClosed", err)
	}

	<-closeReturned
	if err := tconn.Close(); err != net.ErrClosed {
		t.Errorf("Close error = %v; want net.ErrClosed", err)
	}
}
613
// TestConnCloseWrite exercises Conn.CloseWrite. In the first part a client
// and a server each call CloseWrite and each must then read no data and no
// error from the peer, and a Write after CloseWrite must fail with
// errShutdown. The second part checks that CloseWrite on a connection whose
// handshake has not run returns errEarlyCloseWrite.
func TestConnCloseWrite(t *testing.T) {
	ln := newLocalListener(t)
	defer ln.Close()

	// Closed by the client when it has finished; the server waits on it
	// before returning (and thereby closing its side of the connection).
	clientDoneChan := make(chan struct{})

	serverCloseWrite := func() error {
		sconn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("accept: %v", err)
		}
		defer sconn.Close()

		serverConfig := testConfig.Clone()
		srv := Server(sconn, serverConfig)
		if err := srv.Handshake(); err != nil {
			return fmt.Errorf("handshake: %v", err)
		}
		defer srv.Close()

		// Returns once the client has called CloseWrite.
		data, err := io.ReadAll(srv)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}

		if err := srv.CloseWrite(); err != nil {
			return fmt.Errorf("server CloseWrite: %v", err)
		}

		// Wait for clientCloseWrite to finish, so we know we
		// tested the CloseWrite before we defer the
		// sconn.Close above, which would also cause the
		// client to unblock like CloseWrite.
		<-clientDoneChan
		return nil
	}

	clientCloseWrite := func() error {
		defer close(clientDoneChan)

		clientConfig := testConfig.Clone()
		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
		if err != nil {
			return err
		}
		if err := conn.Handshake(); err != nil {
			return err
		}
		defer conn.Close()

		if err := conn.CloseWrite(); err != nil {
			return fmt.Errorf("client CloseWrite: %v", err)
		}

		// Writing after CloseWrite must be refused.
		if _, err := conn.Write([]byte{0}); err != errShutdown {
			return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
		}

		// Returns once the server has called CloseWrite.
		data, err := io.ReadAll(conn)
		if err != nil {
			return err
		}
		if len(data) > 0 {
			return fmt.Errorf("Read data = %q; want nothing", data)
		}
		return nil
	}

	// Buffered for both results so neither goroutine blocks on send if the
	// test has already failed.
	errChan := make(chan error, 2)

	go func() { errChan <- serverCloseWrite() }()
	go func() { errChan <- clientCloseWrite() }()

	for i := 0; i < 2; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				t.Fatal(err)
			}
		case <-time.After(10 * time.Second):
			t.Fatal("deadlock")
		}
	}

	// Also test CloseWrite being called before the handshake is
	// finished:
	{
		ln2 := newLocalListener(t)
		defer ln2.Close()

		netConn, err := net.Dial("tcp", ln2.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		defer netConn.Close()
		conn := Client(netConn, testConfig.Clone())

		if err := conn.CloseWrite(); err != errEarlyCloseWrite {
			t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err)
		}
	}
}
719
720func TestWarningAlertFlood(t *testing.T) {
721	ln := newLocalListener(t)
722	defer ln.Close()
723
724	server := func() error {
725		sconn, err := ln.Accept()
726		if err != nil {
727			return fmt.Errorf("accept: %v", err)
728		}
729		defer sconn.Close()
730
731		serverConfig := testConfig.Clone()
732		srv := Server(sconn, serverConfig)
733		if err := srv.Handshake(); err != nil {
734			return fmt.Errorf("handshake: %v", err)
735		}
736		defer srv.Close()
737
738		_, err = io.ReadAll(srv)
739		if err == nil {
740			return errors.New("unexpected lack of error from server")
741		}
742		const expected = "too many ignored"
743		if str := err.Error(); !strings.Contains(str, expected) {
744			return fmt.Errorf("expected error containing %q, but saw: %s", expected, str)
745		}
746
747		return nil
748	}
749
750	errChan := make(chan error, 1)
751	go func() { errChan <- server() }()
752
753	clientConfig := testConfig.Clone()
754	clientConfig.MaxVersion = VersionTLS12 // there are no warning alerts in TLS 1.3
755	conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
756	if err != nil {
757		t.Fatal(err)
758	}
759	defer conn.Close()
760	if err := conn.Handshake(); err != nil {
761		t.Fatal(err)
762	}
763
764	for i := 0; i < maxUselessRecords+1; i++ {
765		conn.sendAlert(alertNoRenegotiation)
766	}
767
768	if err := <-errChan; err != nil {
769		t.Fatal(err)
770	}
771}
772
773func TestCloneFuncFields(t *testing.T) {
774	const expectedCount = 9
775	called := 0
776
777	c1 := Config{
778		Time: func() time.Time {
779			called |= 1 << 0
780			return time.Time{}
781		},
782		GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
783			called |= 1 << 1
784			return nil, nil
785		},
786		GetClientCertificate: func(*CertificateRequestInfo) (*Certificate, error) {
787			called |= 1 << 2
788			return nil, nil
789		},
790		GetConfigForClient: func(*ClientHelloInfo) (*Config, error) {
791			called |= 1 << 3
792			return nil, nil
793		},
794		VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
795			called |= 1 << 4
796			return nil
797		},
798		VerifyConnection: func(ConnectionState) error {
799			called |= 1 << 5
800			return nil
801		},
802		UnwrapSession: func(identity []byte, cs ConnectionState) (*SessionState, error) {
803			called |= 1 << 6
804			return nil, nil
805		},
806		WrapSession: func(cs ConnectionState, ss *SessionState) ([]byte, error) {
807			called |= 1 << 7
808			return nil, nil
809		},
810		EncryptedClientHelloRejectionVerify: func(ConnectionState) error {
811			called |= 1 << 8
812			return nil
813		},
814	}
815
816	c2 := c1.Clone()
817
818	c2.Time()
819	c2.GetCertificate(nil)
820	c2.GetClientCertificate(nil)
821	c2.GetConfigForClient(nil)
822	c2.VerifyPeerCertificate(nil, nil)
823	c2.VerifyConnection(ConnectionState{})
824	c2.UnwrapSession(nil, ConnectionState{})
825	c2.WrapSession(ConnectionState{}, nil)
826	c2.EncryptedClientHelloRejectionVerify(ConnectionState{})
827
828	if called != (1<<expectedCount)-1 {
829		t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
830	}
831}
832
833func TestCloneNonFuncFields(t *testing.T) {
834	var c1 Config
835	v := reflect.ValueOf(&c1).Elem()
836
837	typ := v.Type()
838	for i := 0; i < typ.NumField(); i++ {
839		f := v.Field(i)
840		// testing/quick can't handle functions or interfaces and so
841		// isn't used here.
842		switch fn := typ.Field(i).Name; fn {
843		case "Rand":
844			f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
845		case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "VerifyConnection", "GetClientCertificate", "WrapSession", "UnwrapSession", "EncryptedClientHelloRejectionVerify":
846			// DeepEqual can't compare functions. If you add a
847			// function field to this list, you must also change
848			// TestCloneFuncFields to ensure that the func field is
849			// cloned.
850		case "Certificates":
851			f.Set(reflect.ValueOf([]Certificate{
852				{Certificate: [][]byte{{'b'}}},
853			}))
854		case "NameToCertificate":
855			f.Set(reflect.ValueOf(map[string]*Certificate{"a": nil}))
856		case "RootCAs", "ClientCAs":
857			f.Set(reflect.ValueOf(x509.NewCertPool()))
858		case "ClientSessionCache":
859			f.Set(reflect.ValueOf(NewLRUClientSessionCache(10)))
860		case "KeyLogWriter":
861			f.Set(reflect.ValueOf(io.Writer(os.Stdout)))
862		case "NextProtos":
863			f.Set(reflect.ValueOf([]string{"a", "b"}))
864		case "ServerName":
865			f.Set(reflect.ValueOf("b"))
866		case "ClientAuth":
867			f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
868		case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
869			f.Set(reflect.ValueOf(true))
870		case "MinVersion", "MaxVersion":
871			f.Set(reflect.ValueOf(uint16(VersionTLS12)))
872		case "SessionTicketKey":
873			f.Set(reflect.ValueOf([32]byte{}))
874		case "CipherSuites":
875			f.Set(reflect.ValueOf([]uint16{1, 2}))
876		case "CurvePreferences":
877			f.Set(reflect.ValueOf([]CurveID{CurveP256}))
878		case "Renegotiation":
879			f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
880		case "EncryptedClientHelloConfigList":
881			f.Set(reflect.ValueOf([]byte{'x'}))
882		case "mutex", "autoSessionTicketKeys", "sessionTicketKeys":
883			continue // these are unexported fields that are handled separately
884		default:
885			t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
886		}
887	}
888	// Set the unexported fields related to session ticket keys, which are copied with Clone().
889	c1.autoSessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
890	c1.sessionTicketKeys = []ticketKey{c1.ticketKeyFromBytes(c1.SessionTicketKey)}
891
892	c2 := c1.Clone()
893	if !reflect.DeepEqual(&c1, c2) {
894		t.Errorf("clone failed to copy a field")
895	}
896}
897
898func TestCloneNilConfig(t *testing.T) {
899	var config *Config
900	if cc := config.Clone(); cc != nil {
901		t.Fatalf("Clone with nil should return nil, got: %+v", cc)
902	}
903}
904
// changeImplConn wraps a net.Conn and lets a test swap out its Write and
// Close behaviour at run time. While a hook is nil, the corresponding call is
// forwarded to the embedded Conn.
type changeImplConn struct {
	net.Conn
	writeFunc func([]byte) (int, error) // replaces Write when non-nil
	closeFunc func() error              // replaces Close when non-nil
}

// Write calls writeFunc if one is installed and the embedded Conn's Write
// otherwise.
func (c *changeImplConn) Write(p []byte) (n int, err error) {
	if c.writeFunc == nil {
		return c.Conn.Write(p)
	}
	return c.writeFunc(p)
}

// Close calls closeFunc if one is installed and the embedded Conn's Close
// otherwise.
func (c *changeImplConn) Close() error {
	if c.closeFunc == nil {
		return c.Conn.Close()
	}
	return c.closeFunc()
}
926
927func throughput(b *testing.B, version uint16, totalBytes int64, dynamicRecordSizingDisabled bool) {
928	ln := newLocalListener(b)
929	defer ln.Close()
930
931	N := b.N
932
933	// Less than 64KB because Windows appears to use a TCP rwin < 64KB.
934	// See Issue #15899.
935	const bufsize = 32 << 10
936
937	go func() {
938		buf := make([]byte, bufsize)
939		for i := 0; i < N; i++ {
940			sconn, err := ln.Accept()
941			if err != nil {
942				// panic rather than synchronize to avoid benchmark overhead
943				// (cannot call b.Fatal in goroutine)
944				panic(fmt.Errorf("accept: %v", err))
945			}
946			serverConfig := testConfig.Clone()
947			serverConfig.CipherSuites = nil // the defaults may prefer faster ciphers
948			serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
949			srv := Server(sconn, serverConfig)
950			if err := srv.Handshake(); err != nil {
951				panic(fmt.Errorf("handshake: %v", err))
952			}
953			if _, err := io.CopyBuffer(srv, srv, buf); err != nil {
954				panic(fmt.Errorf("copy buffer: %v", err))
955			}
956		}
957	}()
958
959	b.SetBytes(totalBytes)
960	clientConfig := testConfig.Clone()
961	clientConfig.CipherSuites = nil // the defaults may prefer faster ciphers
962	clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
963	clientConfig.MaxVersion = version
964
965	buf := make([]byte, bufsize)
966	chunks := int(math.Ceil(float64(totalBytes) / float64(len(buf))))
967	for i := 0; i < N; i++ {
968		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
969		if err != nil {
970			b.Fatal(err)
971		}
972		for j := 0; j < chunks; j++ {
973			_, err := conn.Write(buf)
974			if err != nil {
975				b.Fatal(err)
976			}
977			_, err = io.ReadFull(conn, buf)
978			if err != nil {
979				b.Fatal(err)
980			}
981		}
982		conn.Close()
983	}
984}
985
986func BenchmarkThroughput(b *testing.B) {
987	for _, mode := range []string{"Max", "Dynamic"} {
988		for size := 1; size <= 64; size <<= 1 {
989			name := fmt.Sprintf("%sPacket/%dMB", mode, size)
990			b.Run(name, func(b *testing.B) {
991				b.Run("TLSv12", func(b *testing.B) {
992					throughput(b, VersionTLS12, int64(size<<20), mode == "Max")
993				})
994				b.Run("TLSv13", func(b *testing.B) {
995					throughput(b, VersionTLS13, int64(size<<20), mode == "Max")
996				})
997			})
998		}
999	}
1000}
1001
// slowConn is a net.Conn whose Write is throttled to roughly bps bits per
// second.
type slowConn struct {
	net.Conn
	bps int // permitted write rate in bits per second; must be non-zero
}

// Write sends p over the embedded Conn no faster than the configured rate.
// It wakes up every 100µs, works out how many bytes the elapsed time allows
// in total, and forwards whatever part of p that budget newly covers. It
// panics if bps is zero. On success it returns len(p); if the underlying
// Write fails it returns the number of bytes written so far and that error.
func (c *slowConn) Write(p []byte) (int, error) {
	if c.bps == 0 {
		panic("too slow")
	}
	start := time.Now()
	sent := 0
	for sent < len(p) {
		time.Sleep(100 * time.Microsecond)

		// Bytes permitted since start: bits allowed so far, divided by 8.
		budget := int(time.Since(start).Seconds()*float64(c.bps)) / 8
		if len(p) < budget {
			budget = len(p)
		}
		if budget <= sent {
			continue
		}

		n, err := c.Conn.Write(p[sent:budget])
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return len(p), nil
}
1029
1030func latency(b *testing.B, version uint16, bps int, dynamicRecordSizingDisabled bool) {
1031	ln := newLocalListener(b)
1032	defer ln.Close()
1033
1034	N := b.N
1035
1036	go func() {
1037		for i := 0; i < N; i++ {
1038			sconn, err := ln.Accept()
1039			if err != nil {
1040				// panic rather than synchronize to avoid benchmark overhead
1041				// (cannot call b.Fatal in goroutine)
1042				panic(fmt.Errorf("accept: %v", err))
1043			}
1044			serverConfig := testConfig.Clone()
1045			serverConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1046			srv := Server(&slowConn{sconn, bps}, serverConfig)
1047			if err := srv.Handshake(); err != nil {
1048				panic(fmt.Errorf("handshake: %v", err))
1049			}
1050			io.Copy(srv, srv)
1051		}
1052	}()
1053
1054	clientConfig := testConfig.Clone()
1055	clientConfig.DynamicRecordSizingDisabled = dynamicRecordSizingDisabled
1056	clientConfig.MaxVersion = version
1057
1058	buf := make([]byte, 16384)
1059	peek := make([]byte, 1)
1060
1061	for i := 0; i < N; i++ {
1062		conn, err := Dial("tcp", ln.Addr().String(), clientConfig)
1063		if err != nil {
1064			b.Fatal(err)
1065		}
1066		// make sure we're connected and previous connection has stopped
1067		if _, err := conn.Write(buf[:1]); err != nil {
1068			b.Fatal(err)
1069		}
1070		if _, err := io.ReadFull(conn, peek); err != nil {
1071			b.Fatal(err)
1072		}
1073		if _, err := conn.Write(buf); err != nil {
1074			b.Fatal(err)
1075		}
1076		if _, err = io.ReadFull(conn, peek); err != nil {
1077			b.Fatal(err)
1078		}
1079		conn.Close()
1080	}
1081}
1082
1083func BenchmarkLatency(b *testing.B) {
1084	for _, mode := range []string{"Max", "Dynamic"} {
1085		for _, kbps := range []int{200, 500, 1000, 2000, 5000} {
1086			name := fmt.Sprintf("%sPacket/%dkbps", mode, kbps)
1087			b.Run(name, func(b *testing.B) {
1088				b.Run("TLSv12", func(b *testing.B) {
1089					latency(b, VersionTLS12, kbps*1000, mode == "Max")
1090				})
1091				b.Run("TLSv13", func(b *testing.B) {
1092					latency(b, VersionTLS13, kbps*1000, mode == "Max")
1093				})
1094			})
1095		}
1096	}
1097}
1098
1099func TestConnectionStateMarshal(t *testing.T) {
1100	cs := &ConnectionState{}
1101	_, err := json.Marshal(cs)
1102	if err != nil {
1103		t.Errorf("json.Marshal failed on ConnectionState: %v", err)
1104	}
1105}
1106
1107func TestConnectionState(t *testing.T) {
1108	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1109	if err != nil {
1110		panic(err)
1111	}
1112	rootCAs := x509.NewCertPool()
1113	rootCAs.AddCert(issuer)
1114
1115	now := func() time.Time { return time.Unix(1476984729, 0) }
1116
1117	const alpnProtocol = "golang"
1118	const serverName = "example.golang"
1119	var scts = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1120	var ocsp = []byte("dummy ocsp")
1121
1122	for _, v := range []uint16{VersionTLS12, VersionTLS13} {
1123		var name string
1124		switch v {
1125		case VersionTLS12:
1126			name = "TLSv12"
1127		case VersionTLS13:
1128			name = "TLSv13"
1129		}
1130		t.Run(name, func(t *testing.T) {
1131			config := &Config{
1132				Time:         now,
1133				Rand:         zeroSource{},
1134				Certificates: make([]Certificate, 1),
1135				MaxVersion:   v,
1136				RootCAs:      rootCAs,
1137				ClientCAs:    rootCAs,
1138				ClientAuth:   RequireAndVerifyClientCert,
1139				NextProtos:   []string{alpnProtocol},
1140				ServerName:   serverName,
1141			}
1142			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
1143			config.Certificates[0].PrivateKey = testRSAPrivateKey
1144			config.Certificates[0].SignedCertificateTimestamps = scts
1145			config.Certificates[0].OCSPStaple = ocsp
1146
1147			ss, cs, err := testHandshake(t, config, config)
1148			if err != nil {
1149				t.Fatalf("Handshake failed: %v", err)
1150			}
1151
1152			if ss.Version != v || cs.Version != v {
1153				t.Errorf("Got versions %x (server) and %x (client), expected %x", ss.Version, cs.Version, v)
1154			}
1155
1156			if !ss.HandshakeComplete || !cs.HandshakeComplete {
1157				t.Errorf("Got HandshakeComplete %v (server) and %v (client), expected true", ss.HandshakeComplete, cs.HandshakeComplete)
1158			}
1159
1160			if ss.DidResume || cs.DidResume {
1161				t.Errorf("Got DidResume %v (server) and %v (client), expected false", ss.DidResume, cs.DidResume)
1162			}
1163
1164			if ss.CipherSuite == 0 || cs.CipherSuite == 0 {
1165				t.Errorf("Got invalid cipher suite: %v (server) and %v (client)", ss.CipherSuite, cs.CipherSuite)
1166			}
1167
1168			if ss.NegotiatedProtocol != alpnProtocol || cs.NegotiatedProtocol != alpnProtocol {
1169				t.Errorf("Got negotiated protocol %q (server) and %q (client), expected %q", ss.NegotiatedProtocol, cs.NegotiatedProtocol, alpnProtocol)
1170			}
1171
1172			if !cs.NegotiatedProtocolIsMutual {
1173				t.Errorf("Got false NegotiatedProtocolIsMutual on the client side")
1174			}
1175			// NegotiatedProtocolIsMutual on the server side is unspecified.
1176
1177			if ss.ServerName != serverName {
1178				t.Errorf("Got server name %q, expected %q", ss.ServerName, serverName)
1179			}
1180			if cs.ServerName != serverName {
1181				t.Errorf("Got server name on client connection %q, expected %q", cs.ServerName, serverName)
1182			}
1183
1184			if len(ss.PeerCertificates) != 1 || len(cs.PeerCertificates) != 1 {
1185				t.Errorf("Got %d (server) and %d (client) peer certificates, expected %d", len(ss.PeerCertificates), len(cs.PeerCertificates), 1)
1186			}
1187
1188			if len(ss.VerifiedChains) != 1 || len(cs.VerifiedChains) != 1 {
1189				t.Errorf("Got %d (server) and %d (client) verified chains, expected %d", len(ss.VerifiedChains), len(cs.VerifiedChains), 1)
1190			} else if len(ss.VerifiedChains[0]) != 2 || len(cs.VerifiedChains[0]) != 2 {
1191				t.Errorf("Got %d (server) and %d (client) long verified chain, expected %d", len(ss.VerifiedChains[0]), len(cs.VerifiedChains[0]), 2)
1192			}
1193
1194			if len(cs.SignedCertificateTimestamps) != 2 {
1195				t.Errorf("Got %d SCTs, expected %d", len(cs.SignedCertificateTimestamps), 2)
1196			}
1197			if !bytes.Equal(cs.OCSPResponse, ocsp) {
1198				t.Errorf("Got OCSPs %x, expected %x", cs.OCSPResponse, ocsp)
1199			}
1200			// Only TLS 1.3 supports OCSP and SCTs on client certs.
1201			if v == VersionTLS13 {
1202				if len(ss.SignedCertificateTimestamps) != 2 {
1203					t.Errorf("Got %d client SCTs, expected %d", len(ss.SignedCertificateTimestamps), 2)
1204				}
1205				if !bytes.Equal(ss.OCSPResponse, ocsp) {
1206					t.Errorf("Got client OCSPs %x, expected %x", ss.OCSPResponse, ocsp)
1207				}
1208			}
1209
1210			if v == VersionTLS13 {
1211				if ss.TLSUnique != nil || cs.TLSUnique != nil {
1212					t.Errorf("Got TLSUnique %x (server) and %x (client), expected nil in TLS 1.3", ss.TLSUnique, cs.TLSUnique)
1213				}
1214			} else {
1215				if ss.TLSUnique == nil || cs.TLSUnique == nil {
1216					t.Errorf("Got TLSUnique %x (server) and %x (client), expected non-nil", ss.TLSUnique, cs.TLSUnique)
1217				}
1218			}
1219		})
1220	}
1221}
1222
1223// Issue 28744: Ensure that we don't modify memory
1224// that Config doesn't own such as Certificates.
1225func TestBuildNameToCertificate_doesntModifyCertificates(t *testing.T) {
1226	c0 := Certificate{
1227		Certificate: [][]byte{testRSACertificate},
1228		PrivateKey:  testRSAPrivateKey,
1229	}
1230	c1 := Certificate{
1231		Certificate: [][]byte{testSNICertificate},
1232		PrivateKey:  testRSAPrivateKey,
1233	}
1234	config := testConfig.Clone()
1235	config.Certificates = []Certificate{c0, c1}
1236
1237	config.BuildNameToCertificate()
1238	got := config.Certificates
1239	want := []Certificate{c0, c1}
1240	if !reflect.DeepEqual(got, want) {
1241		t.Fatalf("Certificates were mutated by BuildNameToCertificate\nGot: %#v\nWant: %#v\n", got, want)
1242	}
1243}
1244
// testingKey returns s with every occurrence of "TESTING KEY" replaced by
// "PRIVATE KEY". The PEM fixtures in this file are written with the
// "TESTING KEY" label and converted through this helper before use.
func testingKey(s string) string {
	const placeholder, actual = "TESTING KEY", "PRIVATE KEY"
	return strings.Replace(s, placeholder, actual, -1)
}
1246
1247func TestClientHelloInfo_SupportsCertificate(t *testing.T) {
1248	rsaCert := &Certificate{
1249		Certificate: [][]byte{testRSACertificate},
1250		PrivateKey:  testRSAPrivateKey,
1251	}
1252	pkcs1Cert := &Certificate{
1253		Certificate:                  [][]byte{testRSACertificate},
1254		PrivateKey:                   testRSAPrivateKey,
1255		SupportedSignatureAlgorithms: []SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256},
1256	}
1257	ecdsaCert := &Certificate{
1258		// ECDSA P-256 certificate
1259		Certificate: [][]byte{testP256Certificate},
1260		PrivateKey:  testP256PrivateKey,
1261	}
1262	ed25519Cert := &Certificate{
1263		Certificate: [][]byte{testEd25519Certificate},
1264		PrivateKey:  testEd25519PrivateKey,
1265	}
1266
1267	tests := []struct {
1268		c       *Certificate
1269		chi     *ClientHelloInfo
1270		wantErr string
1271	}{
1272		{rsaCert, &ClientHelloInfo{
1273			ServerName:        "example.golang",
1274			SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1275			SupportedVersions: []uint16{VersionTLS13},
1276		}, ""},
1277		{ecdsaCert, &ClientHelloInfo{
1278			SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1279			SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1280		}, ""},
1281		{rsaCert, &ClientHelloInfo{
1282			ServerName:        "example.com",
1283			SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1284			SupportedVersions: []uint16{VersionTLS13},
1285		}, "not valid for requested server name"},
1286		{ecdsaCert, &ClientHelloInfo{
1287			SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1288			SupportedVersions: []uint16{VersionTLS13},
1289		}, "signature algorithms"},
1290		{pkcs1Cert, &ClientHelloInfo{
1291			SignatureSchemes:  []SignatureScheme{PSSWithSHA256, ECDSAWithP256AndSHA256},
1292			SupportedVersions: []uint16{VersionTLS13},
1293		}, "signature algorithms"},
1294
1295		{rsaCert, &ClientHelloInfo{
1296			CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1297			SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1298			SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1299		}, "signature algorithms"},
1300		{rsaCert, &ClientHelloInfo{
1301			CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1302			SignatureSchemes:  []SignatureScheme{PKCS1WithSHA1},
1303			SupportedVersions: []uint16{VersionTLS13, VersionTLS12},
1304			config: &Config{
1305				CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1306				MaxVersion:   VersionTLS12,
1307			},
1308		}, ""}, // Check that mutual version selection works.
1309
1310		{ecdsaCert, &ClientHelloInfo{
1311			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1312			SupportedCurves:   []CurveID{CurveP256},
1313			SupportedPoints:   []uint8{pointFormatUncompressed},
1314			SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1315			SupportedVersions: []uint16{VersionTLS12},
1316		}, ""},
1317		{ecdsaCert, &ClientHelloInfo{
1318			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1319			SupportedCurves:   []CurveID{CurveP256},
1320			SupportedPoints:   []uint8{pointFormatUncompressed},
1321			SignatureSchemes:  []SignatureScheme{ECDSAWithP384AndSHA384},
1322			SupportedVersions: []uint16{VersionTLS12},
1323		}, ""}, // TLS 1.2 does not restrict curves based on the SignatureScheme.
1324		{ecdsaCert, &ClientHelloInfo{
1325			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1326			SupportedCurves:   []CurveID{CurveP256},
1327			SupportedPoints:   []uint8{pointFormatUncompressed},
1328			SignatureSchemes:  nil,
1329			SupportedVersions: []uint16{VersionTLS12},
1330		}, ""}, // TLS 1.2 comes with default signature schemes.
1331		{ecdsaCert, &ClientHelloInfo{
1332			CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1333			SupportedCurves:   []CurveID{CurveP256},
1334			SupportedPoints:   []uint8{pointFormatUncompressed},
1335			SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1336			SupportedVersions: []uint16{VersionTLS12},
1337		}, "cipher suite"},
1338		{ecdsaCert, &ClientHelloInfo{
1339			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1340			SupportedCurves:   []CurveID{CurveP256},
1341			SupportedPoints:   []uint8{pointFormatUncompressed},
1342			SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1343			SupportedVersions: []uint16{VersionTLS12},
1344			config: &Config{
1345				CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1346			},
1347		}, "cipher suite"},
1348		{ecdsaCert, &ClientHelloInfo{
1349			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1350			SupportedCurves:   []CurveID{CurveP384},
1351			SupportedPoints:   []uint8{pointFormatUncompressed},
1352			SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1353			SupportedVersions: []uint16{VersionTLS12},
1354		}, "certificate curve"},
1355		{ecdsaCert, &ClientHelloInfo{
1356			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1357			SupportedCurves:   []CurveID{CurveP256},
1358			SupportedPoints:   []uint8{1},
1359			SignatureSchemes:  []SignatureScheme{ECDSAWithP256AndSHA256},
1360			SupportedVersions: []uint16{VersionTLS12},
1361		}, "doesn't support ECDHE"},
1362		{ecdsaCert, &ClientHelloInfo{
1363			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1364			SupportedCurves:   []CurveID{CurveP256},
1365			SupportedPoints:   []uint8{pointFormatUncompressed},
1366			SignatureSchemes:  []SignatureScheme{PSSWithSHA256},
1367			SupportedVersions: []uint16{VersionTLS12},
1368		}, "signature algorithms"},
1369
1370		{ed25519Cert, &ClientHelloInfo{
1371			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1372			SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1373			SupportedPoints:   []uint8{pointFormatUncompressed},
1374			SignatureSchemes:  []SignatureScheme{Ed25519},
1375			SupportedVersions: []uint16{VersionTLS12},
1376		}, ""},
1377		{ed25519Cert, &ClientHelloInfo{
1378			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1379			SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1380			SupportedPoints:   []uint8{pointFormatUncompressed},
1381			SignatureSchemes:  []SignatureScheme{Ed25519},
1382			SupportedVersions: []uint16{VersionTLS10},
1383			config:            &Config{MinVersion: VersionTLS10},
1384		}, "doesn't support Ed25519"},
1385		{ed25519Cert, &ClientHelloInfo{
1386			CipherSuites:      []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
1387			SupportedCurves:   []CurveID{},
1388			SupportedPoints:   []uint8{pointFormatUncompressed},
1389			SignatureSchemes:  []SignatureScheme{Ed25519},
1390			SupportedVersions: []uint16{VersionTLS12},
1391		}, "doesn't support ECDHE"},
1392
1393		{rsaCert, &ClientHelloInfo{
1394			CipherSuites:      []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
1395			SupportedCurves:   []CurveID{CurveP256}, // only relevant for ECDHE support
1396			SupportedPoints:   []uint8{pointFormatUncompressed},
1397			SupportedVersions: []uint16{VersionTLS10},
1398			config:            &Config{MinVersion: VersionTLS10},
1399		}, ""},
1400		{rsaCert, &ClientHelloInfo{
1401			CipherSuites:      []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1402			SupportedVersions: []uint16{VersionTLS12},
1403			config: &Config{
1404				CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1405			},
1406		}, ""}, // static RSA fallback
1407	}
1408	for i, tt := range tests {
1409		err := tt.chi.SupportsCertificate(tt.c)
1410		switch {
1411		case tt.wantErr == "" && err != nil:
1412			t.Errorf("%d: unexpected error: %v", i, err)
1413		case tt.wantErr != "" && err == nil:
1414			t.Errorf("%d: unexpected success", i)
1415		case tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr):
1416			t.Errorf("%d: got error %q, expected %q", i, err, tt.wantErr)
1417		}
1418	}
1419}
1420
// TestCipherSuites cross-checks the public cipher suite listings
// (CipherSuites, InsecureCipherSuites, CipherSuiteName) against the internal
// tables (cipherSuites, cipherSuitesTLS13, the preference order lists and the
// disabled/RSA key exchange sets). It verifies ordering by ID, the Insecure
// flag, SupportedVersions, default-list membership, names, and that both
// preference order lists are sorted according to the criteria spelled out in
// isBetter below.
func TestCipherSuites(t *testing.T) {
	// The secure listing must be ordered by ID and contain no insecure suite.
	var lastID uint16
	for _, c := range CipherSuites() {
		if lastID > c.ID {
			t.Errorf("CipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
		} else {
			lastID = c.ID
		}

		if c.Insecure {
			t.Errorf("%#04x: Insecure CipherSuite returned by CipherSuites()", c.ID)
		}
	}
	// The insecure listing must be ordered by ID and contain only insecure
	// suites.
	lastID = 0
	for _, c := range InsecureCipherSuites() {
		if lastID > c.ID {
			t.Errorf("InsecureCipherSuites are not ordered by ID: got %#04x after %#04x", c.ID, lastID)
		} else {
			lastID = c.ID
		}

		if !c.Insecure {
			t.Errorf("%#04x: not Insecure CipherSuite returned by InsecureCipherSuites()", c.ID)
		}
	}

	// CipherSuiteByID looks id up in the two public listings, secure first,
	// and returns nil when neither contains it.
	CipherSuiteByID := func(id uint16) *CipherSuite {
		for _, c := range CipherSuites() {
			if c.ID == id {
				return c
			}
		}
		for _, c := range InsecureCipherSuites() {
			if c.ID == id {
				return c
			}
		}
		return nil
	}

	// Every suite in the internal cipherSuites table needs a consistent
	// public entry.
	for _, c := range cipherSuites {
		cc := CipherSuiteByID(c.id)
		if cc == nil {
			t.Errorf("%#04x: no CipherSuite entry", c.id)
			continue
		}

		// Suites flagged suiteTLS12 must list exactly one supported version;
		// all others must list three.
		if tls12Only := c.flags&suiteTLS12 != 0; tls12Only && len(cc.SupportedVersions) != 1 {
			t.Errorf("%#04x: suite is TLS 1.2 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
		} else if !tls12Only && len(cc.SupportedVersions) != 3 {
			t.Errorf("%#04x: suite TLS 1.0-1.2, but SupportedVersions is %v", c.id, cc.SupportedVersions)
		}

		// Membership in the default list must be the exact opposite of the
		// Insecure flag.
		if cc.Insecure {
			if slices.Contains(defaultCipherSuites(), c.id) {
				t.Errorf("%#04x: insecure suite in default list", c.id)
			}
		} else {
			if !slices.Contains(defaultCipherSuites(), c.id) {
				t.Errorf("%#04x: secure suite not in default list", c.id)
			}
		}

		if got := CipherSuiteName(c.id); got != cc.Name {
			t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
		}
	}
	// Every suite in cipherSuitesTLS13 must be public, secure, and list
	// TLS 1.3 as its only supported version.
	for _, c := range cipherSuitesTLS13 {
		cc := CipherSuiteByID(c.id)
		if cc == nil {
			t.Errorf("%#04x: no CipherSuite entry", c.id)
			continue
		}

		if cc.Insecure {
			t.Errorf("%#04x: Insecure %v, expected false", c.id, cc.Insecure)
		}
		if len(cc.SupportedVersions) != 1 || cc.SupportedVersions[0] != VersionTLS13 {
			t.Errorf("%#04x: suite is TLS 1.3 only, but SupportedVersions is %v", c.id, cc.SupportedVersions)
		}

		if got := CipherSuiteName(c.id); got != cc.Name {
			t.Errorf("%#04x: unexpected CipherSuiteName: got %q, expected %q", c.id, got, cc.Name)
		}
	}

	// An ID with no name is expected to be rendered as four hex digits.
	if got := CipherSuiteName(0xabc); got != "0x0ABC" {
		t.Errorf("unexpected fallback CipherSuiteName: got %q, expected 0x0ABC", got)
	}

	// The preference lists must have the same length as the suite table.
	if len(cipherSuitesPreferenceOrder) != len(cipherSuites) {
		t.Errorf("cipherSuitesPreferenceOrder is not the same size as cipherSuites")
	}
	if len(cipherSuitesPreferenceOrderNoAES) != len(cipherSuitesPreferenceOrder) {
		t.Errorf("cipherSuitesPreferenceOrderNoAES is not the same size as cipherSuitesPreferenceOrder")
	}

	// Check that disabled suites are marked insecure.
	for _, badSuites := range []map[uint16]bool{disabledCipherSuites, rsaKexCiphers} {
		for id := range badSuites {
			c := CipherSuiteByID(id)
			if c == nil {
				t.Errorf("%#04x: no CipherSuite entry", id)
				continue
			}
			if !c.Insecure {
				t.Errorf("%#04x: disabled by default but not marked insecure", id)
			}
		}
	}

	// i is 0 for cipherSuitesPreferenceOrder and 1 for
	// cipherSuitesPreferenceOrderNoAES; isBetter uses it to flip the
	// AES/ChaCha20 criterion.
	for i, prefOrder := range [][]uint16{cipherSuitesPreferenceOrder, cipherSuitesPreferenceOrderNoAES} {
		// Check that insecure and HTTP/2 bad cipher suites are at the end of
		// the preference lists.
		var sawInsecure, sawBad bool
		for _, id := range prefOrder {
			c := CipherSuiteByID(id)
			if c == nil {
				t.Errorf("%#04x: no CipherSuite entry", id)
				continue
			}

			if c.Insecure {
				sawInsecure = true
			} else if sawInsecure {
				t.Errorf("%#04x: secure suite after insecure one(s)", id)
			}

			if http2isBadCipher(id) {
				sawBad = true
			} else if sawBad {
				t.Errorf("%#04x: non-bad suite after bad HTTP/2 one(s)", id)
			}
		}

		// Check that the list is sorted according to the documented criteria.
		//
		// isBetter is a comparison function for slices.IsSortedFunc: it
		// returns -1 when a must come before b and +1 when it must come
		// after. The criteria are applied in order and the first one that
		// tells the two suites apart decides. It never returns 0; a pair that
		// no criterion separates fails the test. Note that it uses the
		// lower-case cipherSuiteByID, which is not the CipherSuiteByID
		// closure defined above and yields a value with flags and aead
		// fields.
		isBetter := func(a, b uint16) int {
			aSuite, bSuite := cipherSuiteByID(a), cipherSuiteByID(b)
			aName, bName := CipherSuiteName(a), CipherSuiteName(b)
			// * < RC4
			if !strings.Contains(aName, "RC4") && strings.Contains(bName, "RC4") {
				return -1
			} else if strings.Contains(aName, "RC4") && !strings.Contains(bName, "RC4") {
				return +1
			}
			// * < CBC_SHA256
			if !strings.Contains(aName, "CBC_SHA256") && strings.Contains(bName, "CBC_SHA256") {
				return -1
			} else if strings.Contains(aName, "CBC_SHA256") && !strings.Contains(bName, "CBC_SHA256") {
				return +1
			}
			// * < 3DES
			if !strings.Contains(aName, "3DES") && strings.Contains(bName, "3DES") {
				return -1
			} else if strings.Contains(aName, "3DES") && !strings.Contains(bName, "3DES") {
				return +1
			}
			// ECDHE < *
			if aSuite.flags&suiteECDHE != 0 && bSuite.flags&suiteECDHE == 0 {
				return -1
			} else if aSuite.flags&suiteECDHE == 0 && bSuite.flags&suiteECDHE != 0 {
				return +1
			}
			// AEAD < CBC
			if aSuite.aead != nil && bSuite.aead == nil {
				return -1
			} else if aSuite.aead == nil && bSuite.aead != nil {
				return +1
			}
			// AES < ChaCha20
			if strings.Contains(aName, "AES") && strings.Contains(bName, "CHACHA20") {
				// negative for cipherSuitesPreferenceOrder
				if i == 0 {
					return -1
				} else {
					return +1
				}
			} else if strings.Contains(aName, "CHACHA20") && strings.Contains(bName, "AES") {
				// negative for cipherSuitesPreferenceOrderNoAES
				if i != 0 {
					return -1
				} else {
					return +1
				}
			}
			// AES-128 < AES-256
			if strings.Contains(aName, "AES_128") && strings.Contains(bName, "AES_256") {
				return -1
			} else if strings.Contains(aName, "AES_256") && strings.Contains(bName, "AES_128") {
				return +1
			}
			// ECDSA < RSA
			if aSuite.flags&suiteECSign != 0 && bSuite.flags&suiteECSign == 0 {
				return -1
			} else if aSuite.flags&suiteECSign == 0 && bSuite.flags&suiteECSign != 0 {
				return +1
			}
			// No criterion separated the two suites: the rules above are
			// incomplete for this pair, so abort the test.
			t.Fatalf("two ciphersuites are equal by all criteria: %v and %v", aName, bName)
			panic("unreachable")
		}
		if !slices.IsSortedFunc(prefOrder, isBetter) {
			t.Error("preference order is not sorted according to the rules")
		}
	}
}
1626
1627func TestVersionName(t *testing.T) {
1628	if got, exp := VersionName(VersionTLS13), "TLS 1.3"; got != exp {
1629		t.Errorf("unexpected VersionName: got %q, expected %q", got, exp)
1630	}
1631	if got, exp := VersionName(0x12a), "0x012A"; got != exp {
1632		t.Errorf("unexpected fallback VersionName: got %q, expected %q", got, exp)
1633	}
1634}
1635
1636// http2isBadCipher is copied from net/http.
1637// TODO: if it ends up exposed somewhere, use that instead.
1638func http2isBadCipher(cipher uint16) bool {
1639	switch cipher {
1640	case TLS_RSA_WITH_RC4_128_SHA,
1641		TLS_RSA_WITH_3DES_EDE_CBC_SHA,
1642		TLS_RSA_WITH_AES_128_CBC_SHA,
1643		TLS_RSA_WITH_AES_256_CBC_SHA,
1644		TLS_RSA_WITH_AES_128_CBC_SHA256,
1645		TLS_RSA_WITH_AES_128_GCM_SHA256,
1646		TLS_RSA_WITH_AES_256_GCM_SHA384,
1647		TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
1648		TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
1649		TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
1650		TLS_ECDHE_RSA_WITH_RC4_128_SHA,
1651		TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
1652		TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
1653		TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
1654		TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
1655		TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
1656		return true
1657	default:
1658		return false
1659	}
1660}
1661
// brokenSigner wraps a crypto.Signer and hides any signing options beyond
// the hash function from it.
type brokenSigner struct{ crypto.Signer }

// Sign forwards to the wrapped signer, but passes only opts.HashFunc() as
// the options, so richer option types such as rsa.PSSOptions are discarded.
func (s brokenSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
	var hashOnly crypto.SignerOpts = opts.HashFunc()
	return s.Signer.Sign(rand, digest, hashOnly)
}
1668
1669// TestPKCS1OnlyCert uses a client certificate with a broken crypto.Signer that
1670// always makes PKCS #1 v1.5 signatures, so can't be used with RSA-PSS.
1671func TestPKCS1OnlyCert(t *testing.T) {
1672	clientConfig := testConfig.Clone()
1673	clientConfig.Certificates = []Certificate{{
1674		Certificate: [][]byte{testRSACertificate},
1675		PrivateKey:  brokenSigner{testRSAPrivateKey},
1676	}}
1677	serverConfig := testConfig.Clone()
1678	serverConfig.MaxVersion = VersionTLS12 // TLS 1.3 doesn't support PKCS #1 v1.5
1679	serverConfig.ClientAuth = RequireAnyClientCert
1680
1681	// If RSA-PSS is selected, the handshake should fail.
1682	if _, _, err := testHandshake(t, clientConfig, serverConfig); err == nil {
1683		t.Fatal("expected broken certificate to cause connection to fail")
1684	}
1685
1686	clientConfig.Certificates[0].SupportedSignatureAlgorithms =
1687		[]SignatureScheme{PKCS1WithSHA1, PKCS1WithSHA256}
1688
1689	// But if the certificate restricts supported algorithms, RSA-PSS should not
1690	// be selected, and the handshake should succeed.
1691	if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1692		t.Error(err)
1693	}
1694}
1695
1696func TestVerifyCertificates(t *testing.T) {
1697	// See https://go.dev/issue/31641.
1698	t.Run("TLSv12", func(t *testing.T) { testVerifyCertificates(t, VersionTLS12) })
1699	t.Run("TLSv13", func(t *testing.T) { testVerifyCertificates(t, VersionTLS13) })
1700}
1701
1702func testVerifyCertificates(t *testing.T, version uint16) {
1703	tests := []struct {
1704		name string
1705
1706		InsecureSkipVerify bool
1707		ClientAuth         ClientAuthType
1708		ClientCertificates bool
1709	}{
1710		{
1711			name: "defaults",
1712		},
1713		{
1714			name:               "InsecureSkipVerify",
1715			InsecureSkipVerify: true,
1716		},
1717		{
1718			name:       "RequestClientCert with no certs",
1719			ClientAuth: RequestClientCert,
1720		},
1721		{
1722			name:               "RequestClientCert with certs",
1723			ClientAuth:         RequestClientCert,
1724			ClientCertificates: true,
1725		},
1726		{
1727			name:               "RequireAnyClientCert",
1728			ClientAuth:         RequireAnyClientCert,
1729			ClientCertificates: true,
1730		},
1731		{
1732			name:       "VerifyClientCertIfGiven with no certs",
1733			ClientAuth: VerifyClientCertIfGiven,
1734		},
1735		{
1736			name:               "VerifyClientCertIfGiven with certs",
1737			ClientAuth:         VerifyClientCertIfGiven,
1738			ClientCertificates: true,
1739		},
1740		{
1741			name:               "RequireAndVerifyClientCert",
1742			ClientAuth:         RequireAndVerifyClientCert,
1743			ClientCertificates: true,
1744		},
1745	}
1746
1747	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1748	if err != nil {
1749		t.Fatal(err)
1750	}
1751	rootCAs := x509.NewCertPool()
1752	rootCAs.AddCert(issuer)
1753
1754	for _, test := range tests {
1755		test := test
1756		t.Run(test.name, func(t *testing.T) {
1757			t.Parallel()
1758
1759			var serverVerifyConnection, clientVerifyConnection bool
1760			var serverVerifyPeerCertificates, clientVerifyPeerCertificates bool
1761
1762			clientConfig := testConfig.Clone()
1763			clientConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
1764			clientConfig.MaxVersion = version
1765			clientConfig.MinVersion = version
1766			clientConfig.RootCAs = rootCAs
1767			clientConfig.ServerName = "example.golang"
1768			clientConfig.ClientSessionCache = NewLRUClientSessionCache(1)
1769			serverConfig := clientConfig.Clone()
1770			serverConfig.ClientCAs = rootCAs
1771
1772			clientConfig.VerifyConnection = func(cs ConnectionState) error {
1773				clientVerifyConnection = true
1774				return nil
1775			}
1776			clientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
1777				clientVerifyPeerCertificates = true
1778				return nil
1779			}
1780			serverConfig.VerifyConnection = func(cs ConnectionState) error {
1781				serverVerifyConnection = true
1782				return nil
1783			}
1784			serverConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
1785				serverVerifyPeerCertificates = true
1786				return nil
1787			}
1788
1789			clientConfig.InsecureSkipVerify = test.InsecureSkipVerify
1790			serverConfig.ClientAuth = test.ClientAuth
1791			if !test.ClientCertificates {
1792				clientConfig.Certificates = nil
1793			}
1794
1795			if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
1796				t.Fatal(err)
1797			}
1798
1799			want := serverConfig.ClientAuth != NoClientCert
1800			if serverVerifyPeerCertificates != want {
1801				t.Errorf("VerifyPeerCertificates on the server: got %v, want %v",
1802					serverVerifyPeerCertificates, want)
1803			}
1804			if !clientVerifyPeerCertificates {
1805				t.Errorf("VerifyPeerCertificates not called on the client")
1806			}
1807			if !serverVerifyConnection {
1808				t.Error("VerifyConnection did not get called on the server")
1809			}
1810			if !clientVerifyConnection {
1811				t.Error("VerifyConnection did not get called on the client")
1812			}
1813
1814			serverVerifyPeerCertificates, clientVerifyPeerCertificates = false, false
1815			serverVerifyConnection, clientVerifyConnection = false, false
1816			cs, _, err := testHandshake(t, clientConfig, serverConfig)
1817			if err != nil {
1818				t.Fatal(err)
1819			}
1820			if !cs.DidResume {
1821				t.Error("expected resumption")
1822			}
1823
1824			if serverVerifyPeerCertificates {
1825				t.Error("VerifyPeerCertificates got called on the server on resumption")
1826			}
1827			if clientVerifyPeerCertificates {
1828				t.Error("VerifyPeerCertificates got called on the client on resumption")
1829			}
1830			if !serverVerifyConnection {
1831				t.Error("VerifyConnection did not get called on the server on resumption")
1832			}
1833			if !clientVerifyConnection {
1834				t.Error("VerifyConnection did not get called on the client on resumption")
1835			}
1836		})
1837	}
1838}
1839
1840func TestHandshakeKyber(t *testing.T) {
1841	if x25519Kyber768Draft00.String() != "X25519Kyber768Draft00" {
1842		t.Fatalf("unexpected CurveID string: %v", x25519Kyber768Draft00.String())
1843	}
1844
1845	var tests = []struct {
1846		name                string
1847		clientConfig        func(*Config)
1848		serverConfig        func(*Config)
1849		preparation         func(*testing.T)
1850		expectClientSupport bool
1851		expectKyber         bool
1852		expectHRR           bool
1853	}{
1854		{
1855			name:                "Default",
1856			expectClientSupport: true,
1857			expectKyber:         true,
1858			expectHRR:           false,
1859		},
1860		{
1861			name: "ClientCurvePreferences",
1862			clientConfig: func(config *Config) {
1863				config.CurvePreferences = []CurveID{X25519}
1864			},
1865			expectClientSupport: false,
1866		},
1867		{
1868			name: "ServerCurvePreferencesX25519",
1869			serverConfig: func(config *Config) {
1870				config.CurvePreferences = []CurveID{X25519}
1871			},
1872			expectClientSupport: true,
1873			expectKyber:         false,
1874			expectHRR:           false,
1875		},
1876		{
1877			name: "ServerCurvePreferencesHRR",
1878			serverConfig: func(config *Config) {
1879				config.CurvePreferences = []CurveID{CurveP256}
1880			},
1881			expectClientSupport: true,
1882			expectKyber:         false,
1883			expectHRR:           true,
1884		},
1885		{
1886			name: "ClientTLSv12",
1887			clientConfig: func(config *Config) {
1888				config.MaxVersion = VersionTLS12
1889			},
1890			expectClientSupport: false,
1891		},
1892		{
1893			name: "ServerTLSv12",
1894			serverConfig: func(config *Config) {
1895				config.MaxVersion = VersionTLS12
1896			},
1897			expectClientSupport: true,
1898			expectKyber:         false,
1899		},
1900		{
1901			name: "GODEBUG",
1902			preparation: func(t *testing.T) {
1903				t.Setenv("GODEBUG", "tlskyber=0")
1904			},
1905			expectClientSupport: false,
1906		},
1907	}
1908
1909	baseConfig := testConfig.Clone()
1910	baseConfig.CurvePreferences = nil
1911	for _, test := range tests {
1912		t.Run(test.name, func(t *testing.T) {
1913			if test.preparation != nil {
1914				test.preparation(t)
1915			} else {
1916				t.Parallel()
1917			}
1918			serverConfig := baseConfig.Clone()
1919			if test.serverConfig != nil {
1920				test.serverConfig(serverConfig)
1921			}
1922			serverConfig.GetConfigForClient = func(hello *ClientHelloInfo) (*Config, error) {
1923				if !test.expectClientSupport && slices.Contains(hello.SupportedCurves, x25519Kyber768Draft00) {
1924					return nil, errors.New("client supports Kyber768Draft00")
1925				} else if test.expectClientSupport && !slices.Contains(hello.SupportedCurves, x25519Kyber768Draft00) {
1926					return nil, errors.New("client does not support Kyber768Draft00")
1927				}
1928				return nil, nil
1929			}
1930			clientConfig := baseConfig.Clone()
1931			if test.clientConfig != nil {
1932				test.clientConfig(clientConfig)
1933			}
1934			ss, cs, err := testHandshake(t, clientConfig, serverConfig)
1935			if err != nil {
1936				t.Fatal(err)
1937			}
1938			if test.expectKyber {
1939				if ss.testingOnlyCurveID != x25519Kyber768Draft00 {
1940					t.Errorf("got CurveID %v (server), expected %v", ss.testingOnlyCurveID, x25519Kyber768Draft00)
1941				}
1942				if cs.testingOnlyCurveID != x25519Kyber768Draft00 {
1943					t.Errorf("got CurveID %v (client), expected %v", cs.testingOnlyCurveID, x25519Kyber768Draft00)
1944				}
1945			} else {
1946				if ss.testingOnlyCurveID == x25519Kyber768Draft00 {
1947					t.Errorf("got CurveID %v (server), expected not Kyber", ss.testingOnlyCurveID)
1948				}
1949				if cs.testingOnlyCurveID == x25519Kyber768Draft00 {
1950					t.Errorf("got CurveID %v (client), expected not Kyber", cs.testingOnlyCurveID)
1951				}
1952			}
1953			if test.expectHRR {
1954				if !ss.testingOnlyDidHRR {
1955					t.Error("server did not use HRR")
1956				}
1957				if !cs.testingOnlyDidHRR {
1958					t.Error("client did not use HRR")
1959				}
1960			} else {
1961				if ss.testingOnlyDidHRR {
1962					t.Error("server used HRR")
1963				}
1964				if cs.testingOnlyDidHRR {
1965					t.Error("client used HRR")
1966				}
1967			}
1968		})
1969	}
1970}
1971
1972func TestX509KeyPairPopulateCertificate(t *testing.T) {
1973	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
1974	if err != nil {
1975		t.Fatal(err)
1976	}
1977	keyDER, err := x509.MarshalPKCS8PrivateKey(key)
1978	if err != nil {
1979		t.Fatal(err)
1980	}
1981	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})
1982	tmpl := &x509.Certificate{
1983		SerialNumber: big.NewInt(1),
1984		Subject:      pkix.Name{CommonName: "test"},
1985	}
1986	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key)
1987	if err != nil {
1988		t.Fatal(err)
1989	}
1990	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
1991
1992	t.Run("x509keypairleaf=0", func(t *testing.T) {
1993		t.Setenv("GODEBUG", "x509keypairleaf=0")
1994		cert, err := X509KeyPair(certPEM, keyPEM)
1995		if err != nil {
1996			t.Fatal(err)
1997		}
1998		if cert.Leaf != nil {
1999			t.Fatal("Leaf should not be populated")
2000		}
2001	})
2002	t.Run("x509keypairleaf=1", func(t *testing.T) {
2003		t.Setenv("GODEBUG", "x509keypairleaf=1")
2004		cert, err := X509KeyPair(certPEM, keyPEM)
2005		if err != nil {
2006			t.Fatal(err)
2007		}
2008		if cert.Leaf == nil {
2009			t.Fatal("Leaf should be populated")
2010		}
2011	})
2012	t.Run("GODEBUG unset", func(t *testing.T) {
2013		cert, err := X509KeyPair(certPEM, keyPEM)
2014		if err != nil {
2015			t.Fatal(err)
2016		}
2017		if cert.Leaf == nil {
2018			t.Fatal("Leaf should be populated")
2019		}
2020	})
2021}
2022
2023func TestEarlyLargeCertMsg(t *testing.T) {
2024	client, server := localPipe(t)
2025
2026	go func() {
2027		if _, err := client.Write([]byte{byte(recordTypeHandshake), 3, 4, 0, 4, typeCertificate, 1, 255, 255}); err != nil {
2028			t.Log(err)
2029		}
2030	}()
2031
2032	expectedErr := "tls: handshake message of length 131071 bytes exceeds maximum of 65536 bytes"
2033	servConn := Server(server, testConfig)
2034	err := servConn.Handshake()
2035	if err == nil {
2036		t.Fatal("unexpected success")
2037	}
2038	if err.Error() != expectedErr {
2039		t.Fatalf("unexpected error: got %q, want %q", err, expectedErr)
2040	}
2041}
2042
2043func TestLargeCertMsg(t *testing.T) {
2044	k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
2045	if err != nil {
2046		t.Fatal(err)
2047	}
2048	tmpl := &x509.Certificate{
2049		SerialNumber: big.NewInt(1),
2050		Subject:      pkix.Name{CommonName: "test"},
2051		ExtraExtensions: []pkix.Extension{
2052			{
2053				Id: asn1.ObjectIdentifier{1, 2, 3},
2054				// Ballast to inflate the certificate beyond the
2055				// regular handshake record size.
2056				Value: make([]byte, 65536),
2057			},
2058		},
2059	}
2060	cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, k.Public(), k)
2061	if err != nil {
2062		t.Fatal(err)
2063	}
2064
2065	clientConfig, serverConfig := testConfig.Clone(), testConfig.Clone()
2066	clientConfig.InsecureSkipVerify = true
2067	serverConfig.Certificates = []Certificate{
2068		{
2069			Certificate: [][]byte{cert},
2070			PrivateKey:  k,
2071		},
2072	}
2073	if _, _, err := testHandshake(t, clientConfig, serverConfig); err != nil {
2074		t.Fatalf("unexpected failure :%s", err)
2075	}
2076}
2077