1package tls
2
3import (
4	"bytes"
5	"crypto/x509"
6	"encoding/base64"
7	"encoding/json"
8	"encoding/pem"
9	"flag"
10	"fmt"
11	"internal/byteorder"
12	"internal/testenv"
13	"io"
14	"log"
15	"net"
16	"os"
17	"os/exec"
18	"path/filepath"
19	"runtime"
20	"strconv"
21	"strings"
22	"testing"
23)
24
// Command-line flags understood by bogoShim. Flags bound to the blank
// identifier are registered so that flag parsing accepts them, but their
// values are never read (presumably these are BoGo options this shim does
// not implement — TODO confirm against bogo_config.json).
var (
	// Address of the runner (localhost:port), and whether the shim acts as the
	// TLS server (Server) instead of the client (Client).
	port   = flag.String("port", "", "")
	server = flag.Bool("server", false, "")

	// When set, bogoShim prints "No" and returns without connecting.
	isHandshakerSupported = flag.Bool("is-handshaker-supported", false, "")

	// Certificate/key pair loaded with LoadX509KeyPair into Config.Certificates.
	keyfile  = flag.String("key-file", "", "")
	certfile = flag.String("cert-file", "", "")

	// PEM certificate installed as the only entry of Config.RootCAs.
	trustCert = flag.String("trust-cert", "", "")

	// Initial Config.MinVersion/MaxVersion, and the version the completed
	// handshake must have negotiated (0 disables the check).
	minVersion    = flag.Int("min-version", VersionSSL30, "")
	maxVersion    = flag.Int("max-version", VersionTLS13, "")
	expectVersion = flag.Int("expect-version", 0, "")

	// Disable individual protocol versions by trimming the Min/MaxVersion range.
	noTLS1  = flag.Bool("no-tls1", false, "")
	noTLS11 = flag.Bool("no-tls11", false, "")
	noTLS12 = flag.Bool("no-tls12", false, "")
	noTLS13 = flag.Bool("no-tls13", false, "")

	// Sets Config.ClientAuth to RequireAnyClientCert.
	requireAnyClientCertificate = flag.Bool("require-any-client-certificate", false, "")

	// Write "hello" before reading, on the initial connection only.
	shimWritesFirst = flag.Bool("shim-writes-first", false, "")

	// Number of additional connections made after the initial one.
	resumeCount = flag.Int("resume-count", 0, "")

	// Each -curves occurrence is meant to append a decimal CurveID to
	// Config.CurvePreferences; -expect-curve-id is the decimal CurveID the
	// connection must end up using.
	curves        = flagStringSlice("curves", "")
	expectedCurve = flag.String("expect-curve-id", "", "")

	// Written to the runner as a little-endian uint64 at the start of every
	// connection.
	shimID = flag.Uint64("shim-id", 0, "")
	_      = flag.Bool("ipv6", false, "")

	// Encrypted Client Hello: base64 config lists (the on-resume one replaces
	// the initial list for connections after the first), and expectations on
	// acceptance, HelloRetryRequest, retry configs and the server name.
	echConfigListB64           = flag.String("ech-config-list", "", "")
	expectECHAccepted          = flag.Bool("expect-ech-accept", false, "")
	expectHRR                  = flag.Bool("expect-hrr", false, "")
	expectNoHRR                = flag.Bool("expect-no-hrr", false, "")
	expectedECHRetryConfigs    = flag.String("expect-ech-retry-configs", "", "")
	expectNoECHRetryConfigs    = flag.Bool("expect-no-ech-retry-configs", false, "")
	onInitialExpectECHAccepted = flag.Bool("on-initial-expect-ech-accept", false, "")
	_                          = flag.Bool("expect-no-ech-name-override", false, "")
	_                          = flag.String("expect-ech-name-override", "", "")
	_                          = flag.Bool("reverify-on-resume", false, "")
	onResumeECHConfigListB64   = flag.String("on-resume-ech-config-list", "", "")
	_                          = flag.Bool("on-resume-expect-reject-early-data", false, "")
	onResumeExpectECHAccepted  = flag.Bool("on-resume-expect-ech-accept", false, "")
	_                          = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
	expectedServerName         = flag.String("expect-server-name", "", "")

	// Fail if the connection resumed a session.
	expectSessionMiss = flag.Bool("expect-session-miss", false, "")

	// Early-data flags are accepted but ignored.
	// NOTE(review): onResumeShimWritesFirst is declared but bogoShim never
	// reads it — confirm whether resumed connections should honor it.
	_                       = flag.Bool("enable-early-data", false, "")
	_                       = flag.Bool("on-resume-expect-accept-early-data", false, "")
	_                       = flag.Bool("expect-ticket-supports-early-data", false, "")
	onResumeShimWritesFirst = flag.Bool("on-resume-shim-writes-first", false, "")

	// ALPN: -advertise-alpn is a sequence of length-prefixed protocol names
	// for Config.NextProtos; -reject-alpn and -decline-alpn override it.
	advertiseALPN = flag.String("advertise-alpn", "", "")
	expectALPN    = flag.String("expect-alpn", "", "")
	rejectALPN    = flag.Bool("reject-alpn", false, "")
	declineALPN   = flag.Bool("decline-alpn", false, "")

	// Overrides the default Config.ServerName of "test".
	hostName = flag.String("host-name", "", "")

	// Sets Config.ClientAuth to VerifyClientCertIfGiven.
	verifyPeer = flag.Bool("verify-peer", false, "")
	_          = flag.Bool("use-custom-verify-callback", false, "")
)
90
// stringSlice is a flag.Value that collects every occurrence of a repeatable
// command-line flag, in the order the occurrences were given.
type stringSlice []string

// flagStringSlice registers a repeatable string flag with the given name and
// usage on the default flag set. It returns the slice that the flag's values
// are appended to.
func flagStringSlice(name, usage string) *stringSlice {
	f := &stringSlice{}
	flag.Var(f, name, usage)
	return f
}

// String returns the collected values joined with commas. It accepts a nil
// receiver, because the flag package may call String on a zero value.
func (saf *stringSlice) String() string {
	if saf == nil {
		return ""
	}
	return strings.Join(*saf, ",")
}

// Set appends one occurrence of the flag. The pointer receiver is required.
// With a value receiver the append would update only a local copy, and the
// flag would silently record nothing.
func (saf *stringSlice) Set(s string) error {
	*saf = append(*saf, s)
	return nil
}
107
108func bogoShim() {
109	if *isHandshakerSupported {
110		fmt.Println("No")
111		return
112	}
113
114	cfg := &Config{
115		ServerName: "test",
116
117		MinVersion: uint16(*minVersion),
118		MaxVersion: uint16(*maxVersion),
119
120		ClientSessionCache: NewLRUClientSessionCache(0),
121	}
122
123	if *noTLS1 {
124		cfg.MinVersion = VersionTLS11
125		if *noTLS11 {
126			cfg.MinVersion = VersionTLS12
127			if *noTLS12 {
128				cfg.MinVersion = VersionTLS13
129				if *noTLS13 {
130					log.Fatalf("no supported versions enabled")
131				}
132			}
133		}
134	} else if *noTLS13 {
135		cfg.MaxVersion = VersionTLS12
136		if *noTLS12 {
137			cfg.MaxVersion = VersionTLS11
138			if *noTLS11 {
139				cfg.MaxVersion = VersionTLS10
140				if *noTLS1 {
141					log.Fatalf("no supported versions enabled")
142				}
143			}
144		}
145	}
146
147	if *advertiseALPN != "" {
148		alpns := *advertiseALPN
149		for len(alpns) > 0 {
150			alpnLen := int(alpns[0])
151			cfg.NextProtos = append(cfg.NextProtos, alpns[1:1+alpnLen])
152			alpns = alpns[alpnLen+1:]
153		}
154	}
155
156	if *rejectALPN {
157		cfg.NextProtos = []string{"unnegotiableprotocol"}
158	}
159
160	if *declineALPN {
161		cfg.NextProtos = []string{}
162	}
163
164	if *hostName != "" {
165		cfg.ServerName = *hostName
166	}
167
168	if *keyfile != "" || *certfile != "" {
169		pair, err := LoadX509KeyPair(*certfile, *keyfile)
170		if err != nil {
171			log.Fatalf("load key-file err: %s", err)
172		}
173		cfg.Certificates = []Certificate{pair}
174	}
175	if *trustCert != "" {
176		pool := x509.NewCertPool()
177		certFile, err := os.ReadFile(*trustCert)
178		if err != nil {
179			log.Fatalf("load trust-cert err: %s", err)
180		}
181		block, _ := pem.Decode(certFile)
182		cert, err := x509.ParseCertificate(block.Bytes)
183		if err != nil {
184			log.Fatalf("parse trust-cert err: %s", err)
185		}
186		pool.AddCert(cert)
187		cfg.RootCAs = pool
188	}
189
190	if *requireAnyClientCertificate {
191		cfg.ClientAuth = RequireAnyClientCert
192	}
193	if *verifyPeer {
194		cfg.ClientAuth = VerifyClientCertIfGiven
195	}
196
197	if *echConfigListB64 != "" {
198		echConfigList, err := base64.StdEncoding.DecodeString(*echConfigListB64)
199		if err != nil {
200			log.Fatalf("parse ech-config-list err: %s", err)
201		}
202		cfg.EncryptedClientHelloConfigList = echConfigList
203		cfg.MinVersion = VersionTLS13
204	}
205
206	if len(*curves) != 0 {
207		for _, curveStr := range *curves {
208			id, err := strconv.Atoi(curveStr)
209			if err != nil {
210				log.Fatalf("failed to parse curve id %q: %s", curveStr, err)
211			}
212			cfg.CurvePreferences = append(cfg.CurvePreferences, CurveID(id))
213		}
214	}
215
216	for i := 0; i < *resumeCount+1; i++ {
217		if i > 0 && (*onResumeECHConfigListB64 != "") {
218			echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
219			if err != nil {
220				log.Fatalf("parse ech-config-list err: %s", err)
221			}
222			cfg.EncryptedClientHelloConfigList = echConfigList
223		}
224
225		conn, err := net.Dial("tcp", net.JoinHostPort("localhost", *port))
226		if err != nil {
227			log.Fatalf("dial err: %s", err)
228		}
229		defer conn.Close()
230
231		// Write the shim ID we were passed as a little endian uint64
232		shimIDBytes := make([]byte, 8)
233		byteorder.LePutUint64(shimIDBytes, *shimID)
234		if _, err := conn.Write(shimIDBytes); err != nil {
235			log.Fatalf("failed to write shim id: %s", err)
236		}
237
238		var tlsConn *Conn
239		if *server {
240			tlsConn = Server(conn, cfg)
241		} else {
242			tlsConn = Client(conn, cfg)
243		}
244
245		if i == 0 && *shimWritesFirst {
246			if _, err := tlsConn.Write([]byte("hello")); err != nil {
247				log.Fatalf("write err: %s", err)
248			}
249		}
250
251		for {
252			buf := make([]byte, 500)
253			var n int
254			n, err = tlsConn.Read(buf)
255			if err != nil {
256				break
257			}
258			buf = buf[:n]
259			for i := range buf {
260				buf[i] ^= 0xff
261			}
262			if _, err = tlsConn.Write(buf); err != nil {
263				break
264			}
265		}
266		if err != nil && err != io.EOF {
267			retryErr, ok := err.(*ECHRejectionError)
268			if !ok {
269				log.Fatalf("unexpected error type returned: %v", err)
270			}
271			if *expectNoECHRetryConfigs && len(retryErr.RetryConfigList) > 0 {
272				log.Fatalf("expected no ECH retry configs, got some")
273			}
274			if *expectedECHRetryConfigs != "" {
275				expectedRetryConfigs, err := base64.StdEncoding.DecodeString(*expectedECHRetryConfigs)
276				if err != nil {
277					log.Fatalf("failed to decode expected retry configs: %s", err)
278				}
279				if !bytes.Equal(retryErr.RetryConfigList, expectedRetryConfigs) {
280					log.Fatalf("unexpected retry list returned: got %x, want %x", retryErr.RetryConfigList, expectedRetryConfigs)
281				}
282			}
283			log.Fatalf("conn error: %s", err)
284		}
285
286		cs := tlsConn.ConnectionState()
287		if cs.HandshakeComplete {
288			if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
289				log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
290			}
291			if *expectVersion != 0 && cs.Version != uint16(*expectVersion) {
292				log.Fatalf("expected ssl version %q, got %q", uint16(*expectVersion), cs.Version)
293			}
294			if *declineALPN && cs.NegotiatedProtocol != "" {
295				log.Fatal("unexpected ALPN protocol")
296			}
297			if *expectECHAccepted && !cs.ECHAccepted {
298				log.Fatal("expected ECH to be accepted, but connection state shows it was not")
299			} else if i == 0 && *onInitialExpectECHAccepted && !cs.ECHAccepted {
300				log.Fatal("expected ECH to be accepted, but connection state shows it was not")
301			} else if i > 0 && *onResumeExpectECHAccepted && !cs.ECHAccepted {
302				log.Fatal("expected ECH to be accepted on resumption, but connection state shows it was not")
303			} else if i == 0 && !*expectECHAccepted && cs.ECHAccepted {
304				log.Fatal("did not expect ECH, but it was accepted")
305			}
306
307			if *expectHRR && !cs.testingOnlyDidHRR {
308				log.Fatal("expected HRR but did not do it")
309			}
310
311			if *expectNoHRR && cs.testingOnlyDidHRR {
312				log.Fatal("expected no HRR but did do it")
313			}
314
315			if *expectSessionMiss && cs.DidResume {
316				log.Fatal("unexpected session resumption")
317			}
318
319			if *expectedServerName != "" && cs.ServerName != *expectedServerName {
320				log.Fatalf("unexpected server name: got %q, want %q", cs.ServerName, *expectedServerName)
321			}
322		}
323
324		if *expectedCurve != "" {
325			expectedCurveID, err := strconv.Atoi(*expectedCurve)
326			if err != nil {
327				log.Fatalf("failed to parse -expect-curve-id: %s", err)
328			}
329			if tlsConn.curveID != CurveID(expectedCurveID) {
330				log.Fatalf("unexpected curve id: want %d, got %d", expectedCurveID, tlsConn.curveID)
331			}
332		}
333	}
334}
335
336func TestBogoSuite(t *testing.T) {
337	testenv.SkipIfShortAndSlow(t)
338	testenv.MustHaveExternalNetwork(t)
339	testenv.MustHaveGoRun(t)
340	testenv.MustHaveExec(t)
341
342	if testing.Short() {
343		t.Skip("skipping in short mode")
344	}
345	if testenv.Builder() != "" && runtime.GOOS == "windows" {
346		t.Skip("#66913: windows network connections are flakey on builders")
347	}
348
349	// In order to make Go test caching work as expected, we stat the
350	// bogo_config.json file, so that the Go testing hooks know that it is
351	// important for this test and will invalidate a cached test result if the
352	// file changes.
353	if _, err := os.Stat("bogo_config.json"); err != nil {
354		t.Fatal(err)
355	}
356
357	var bogoDir string
358	if *bogoLocalDir != "" {
359		bogoDir = *bogoLocalDir
360	} else {
361		const boringsslModVer = "v0.0.0-20240523173554-273a920f84e8"
362		output, err := exec.Command("go", "mod", "download", "-json", "boringssl.googlesource.com/boringssl.git@"+boringsslModVer).CombinedOutput()
363		if err != nil {
364			t.Fatalf("failed to download boringssl: %s", err)
365		}
366		var j struct {
367			Dir string
368		}
369		if err := json.Unmarshal(output, &j); err != nil {
370			t.Fatalf("failed to parse 'go mod download' output: %s", err)
371		}
372		bogoDir = j.Dir
373	}
374
375	cwd, err := os.Getwd()
376	if err != nil {
377		t.Fatal(err)
378	}
379
380	resultsFile := filepath.Join(t.TempDir(), "results.json")
381
382	args := []string{
383		"test",
384		".",
385		fmt.Sprintf("-shim-config=%s", filepath.Join(cwd, "bogo_config.json")),
386		fmt.Sprintf("-shim-path=%s", os.Args[0]),
387		"-shim-extra-flags=-bogo-mode",
388		"-allow-unimplemented",
389		"-loose-errors", // TODO(roland): this should be removed eventually
390		fmt.Sprintf("-json-output=%s", resultsFile),
391	}
392	if *bogoFilter != "" {
393		args = append(args, fmt.Sprintf("-test=%s", *bogoFilter))
394	}
395
396	goCmd, err := testenv.GoTool()
397	if err != nil {
398		t.Fatal(err)
399	}
400	cmd := exec.Command(goCmd, args...)
401	out := &strings.Builder{}
402	cmd.Stderr = out
403	cmd.Dir = filepath.Join(bogoDir, "ssl/test/runner")
404	err = cmd.Run()
405	// NOTE: we don't immediately check the error, because the failure could be either because
406	// the runner failed for some unexpected reason, or because a test case failed, and we
407	// cannot easily differentiate these cases. We check if the JSON results file was written,
408	// which should only happen if the failure was because of a test failure, and use that
409	// to determine the failure mode.
410
411	resultsJSON, jsonErr := os.ReadFile(resultsFile)
412	if jsonErr != nil {
413		if err != nil {
414			t.Fatalf("bogo failed: %s\n%s", err, out)
415		}
416		t.Fatalf("failed to read results JSON file: %s", jsonErr)
417	}
418
419	var results bogoResults
420	if err := json.Unmarshal(resultsJSON, &results); err != nil {
421		t.Fatalf("failed to parse results JSON: %s", err)
422	}
423
424	// assertResults contains test results we want to make sure
425	// are present in the output. They are only checked if -bogo-filter
426	// was not passed.
427	assertResults := map[string]string{
428		"CurveTest-Client-Kyber-TLS13": "PASS",
429		"CurveTest-Server-Kyber-TLS13": "PASS",
430	}
431
432	for name, result := range results.Tests {
433		// This is not really the intended way to do this... but... it works?
434		t.Run(name, func(t *testing.T) {
435			if result.Actual == "FAIL" && result.IsUnexpected {
436				t.Fatal(result.Error)
437			}
438			if expectedResult, ok := assertResults[name]; ok && expectedResult != result.Actual {
439				t.Fatalf("unexpected result: got %s, want %s", result.Actual, assertResults[name])
440			}
441			delete(assertResults, name)
442			if result.Actual == "SKIP" {
443				t.Skip()
444			}
445		})
446	}
447	if *bogoFilter == "" {
448		// Anything still in assertResults did not show up in the results, so we should fail
449		for name, expectedResult := range assertResults {
450			t.Run(name, func(t *testing.T) {
451				t.Fatalf("expected test to run with result %s, but it was not present in the test results", expectedResult)
452			})
453		}
454	}
455}
456
// bogoResults mirrors boringssl.googlesource.com/boringssl/testresults.Results,
// the JSON document the BoGo runner writes to the file named by -json-output.
type bogoResults struct {
	Version           int            `json:"version"`
	Interrupted       bool           `json:"interrupted"`
	PathDelimiter     string         `json:"path_delimiter"`
	SecondsSinceEpoch float64        `json:"seconds_since_epoch"`
	NumFailuresByType map[string]int `json:"num_failures_by_type"`

	// Tests maps each BoGo test name to its outcome.
	Tests map[string]bogoTestResult `json:"tests"`
}

// bogoTestResult is the outcome of a single BoGo test case, as stored in
// bogoResults.Tests.
type bogoTestResult struct {
	// Actual is the result the test produced, e.g. "PASS", "FAIL" or "SKIP".
	Actual string `json:"actual"`
	// Expected is the result the runner anticipated for the test.
	Expected string `json:"expected"`
	// IsUnexpected reports whether Actual differed from what was expected.
	IsUnexpected bool `json:"is_unexpected"`
	// Error holds the failure description, if any.
	Error string `json:"error,omitempty"`
}
471