1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package tls 6 7import ( 8 "bytes" 9 "crypto/x509" 10 "encoding/hex" 11 "math" 12 "math/rand" 13 "reflect" 14 "strings" 15 "testing" 16 "testing/quick" 17 "time" 18) 19 20var tests = []handshakeMessage{ 21 &clientHelloMsg{}, 22 &serverHelloMsg{}, 23 &finishedMsg{}, 24 25 &certificateMsg{}, 26 &certificateRequestMsg{}, 27 &certificateVerifyMsg{ 28 hasSignatureAlgorithm: true, 29 }, 30 &certificateStatusMsg{}, 31 &clientKeyExchangeMsg{}, 32 &newSessionTicketMsg{}, 33 &encryptedExtensionsMsg{}, 34 &endOfEarlyDataMsg{}, 35 &keyUpdateMsg{}, 36 &newSessionTicketMsgTLS13{}, 37 &certificateRequestMsgTLS13{}, 38 &certificateMsgTLS13{}, 39 &SessionState{}, 40} 41 42func mustMarshal(t *testing.T, msg handshakeMessage) []byte { 43 t.Helper() 44 b, err := msg.marshal() 45 if err != nil { 46 t.Fatal(err) 47 } 48 return b 49} 50 51func TestMarshalUnmarshal(t *testing.T) { 52 rand := rand.New(rand.NewSource(time.Now().UnixNano())) 53 54 for i, m := range tests { 55 ty := reflect.ValueOf(m).Type() 56 t.Run(ty.String(), func(t *testing.T) { 57 n := 100 58 if testing.Short() { 59 n = 5 60 } 61 for j := 0; j < n; j++ { 62 v, ok := quick.Value(ty, rand) 63 if !ok { 64 t.Errorf("#%d: failed to create value", i) 65 break 66 } 67 68 m1 := v.Interface().(handshakeMessage) 69 marshaled := mustMarshal(t, m1) 70 if !m.unmarshal(marshaled) { 71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 72 break 73 } 74 75 if m, ok := m.(*SessionState); ok { 76 m.activeCertHandles = nil 77 } 78 79 // clientHelloMsg and serverHelloMsg, when unmarshalled, store 80 // their original representation, for later use in the handshake 81 // transcript. In order to prevent DeepEqual from failing since 82 // we didn't create the original message via unmarshalling, nil 83 // the field. 84 switch t := m.(type) { 85 case *clientHelloMsg: 86 t.original = nil 87 case *serverHelloMsg: 88 t.original = nil 89 } 90 91 if !reflect.DeepEqual(m1, m) { 92 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled) 93 break 94 } 95 96 if i >= 3 { 97 // The first three message types (ClientHello, 98 // ServerHello and Finished) are allowed to 99 // have parsable prefixes because the extension 100 // data is optional and the length of the 101 // Finished varies across versions. 102 for j := 0; j < len(marshaled); j++ { 103 if m.unmarshal(marshaled[0:j]) { 104 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 105 break 106 } 107 } 108 } 109 } 110 }) 111 } 112} 113 114func TestFuzz(t *testing.T) { 115 rand := rand.New(rand.NewSource(0)) 116 for _, m := range tests { 117 for j := 0; j < 1000; j++ { 118 len := rand.Intn(1000) 119 bytes := randomBytes(len, rand) 120 // This just looks for crashes due to bounds errors etc. 121 m.unmarshal(bytes) 122 } 123 } 124} 125 126func randomBytes(n int, rand *rand.Rand) []byte { 127 r := make([]byte, n) 128 if _, err := rand.Read(r); err != nil { 129 panic("rand.Read failed: " + err.Error()) 130 } 131 return r 132} 133 134func randomString(n int, rand *rand.Rand) string { 135 b := randomBytes(n, rand) 136 return string(b) 137} 138 139func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 140 m := &clientHelloMsg{} 141 m.vers = uint16(rand.Intn(65536)) 142 m.random = randomBytes(32, rand) 143 m.sessionId = randomBytes(rand.Intn(32), rand) 144 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 145 for i := 0; i < len(m.cipherSuites); i++ { 146 cs := uint16(rand.Int31()) 147 if cs == scsvRenegotiation { 148 cs += 1 149 } 150 m.cipherSuites[i] = cs 151 } 152 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 153 if rand.Intn(10) > 5 { 154 m.serverName = randomString(rand.Intn(255), rand) 155 for strings.HasSuffix(m.serverName, ".") { 156 m.serverName = m.serverName[:len(m.serverName)-1] 157 } 158 } 159 m.ocspStapling = rand.Intn(10) > 5 160 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 161 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 162 for i := range m.supportedCurves { 163 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 164 } 165 if rand.Intn(10) > 5 { 166 m.ticketSupported = true 167 if rand.Intn(10) > 5 { 168 m.sessionTicket = randomBytes(rand.Intn(300), rand) 169 } else { 170 m.sessionTicket = make([]byte, 0) 171 } 172 } 173 if rand.Intn(10) > 5 { 174 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 175 } 176 if rand.Intn(10) > 5 { 177 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 178 } 179 for i := 0; i < rand.Intn(5); i++ { 180 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 181 } 182 if rand.Intn(10) > 5 { 183 m.scts = true 184 } 185 if rand.Intn(10) > 5 { 186 m.secureRenegotiationSupported = true 187 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 188 } 189 if rand.Intn(10) > 5 { 190 m.extendedMasterSecret = true 191 } 192 for i := 0; i < rand.Intn(5); i++ { 193 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 194 } 195 if rand.Intn(10) > 5 { 196 m.cookie = randomBytes(rand.Intn(500)+1, rand) 197 } 198 for i := 0; i < rand.Intn(5); i++ { 199 var ks keyShare 200 ks.group = CurveID(rand.Intn(30000) + 1) 201 ks.data = randomBytes(rand.Intn(200)+1, rand) 202 m.keyShares = append(m.keyShares, ks) 203 } 204 switch rand.Intn(3) { 205 case 1: 206 m.pskModes = []uint8{pskModeDHE} 207 case 2: 208 m.pskModes = []uint8{pskModeDHE, pskModePlain} 209 } 210 for i := 0; i < rand.Intn(5); i++ { 211 var psk pskIdentity 212 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 213 psk.label = randomBytes(rand.Intn(500)+1, rand) 214 m.pskIdentities = append(m.pskIdentities, psk) 215 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 216 } 217 if rand.Intn(10) > 5 { 218 m.quicTransportParameters = randomBytes(rand.Intn(500), rand) 219 } 220 if rand.Intn(10) > 5 { 221 m.earlyData = true 222 } 223 224 return reflect.ValueOf(m) 225} 226 227func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 228 m := &serverHelloMsg{} 229 m.vers = uint16(rand.Intn(65536)) 230 m.random = randomBytes(32, rand) 231 m.sessionId = randomBytes(rand.Intn(32), rand) 232 m.cipherSuite = uint16(rand.Int31()) 233 m.compressionMethod = uint8(rand.Intn(256)) 234 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 235 236 if rand.Intn(10) > 5 { 237 m.ocspStapling = true 238 } 239 if rand.Intn(10) > 5 { 240 m.ticketSupported = true 241 } 242 if rand.Intn(10) > 5 { 243 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 244 } 245 246 for i := 0; i < rand.Intn(4); i++ { 247 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 248 } 249 250 if rand.Intn(10) > 5 { 251 m.secureRenegotiationSupported = true 252 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 253 } 254 if rand.Intn(10) > 5 { 255 m.extendedMasterSecret = true 256 } 257 if rand.Intn(10) > 5 { 258 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 259 } 260 if rand.Intn(10) > 5 { 261 m.cookie = randomBytes(rand.Intn(500)+1, rand) 262 } 263 if rand.Intn(10) > 5 { 264 for i := 0; i < rand.Intn(5); i++ { 265 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 266 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 267 } 268 } else if rand.Intn(10) > 5 { 269 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 270 } 271 if rand.Intn(10) > 5 { 272 m.selectedIdentityPresent = true 273 m.selectedIdentity = uint16(rand.Intn(0xffff)) 274 } 275 if rand.Intn(10) > 5 { 276 m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand) 277 } 278 if rand.Intn(10) > 5 { 279 m.serverNameAck = rand.Intn(2) == 1 280 } 281 282 return reflect.ValueOf(m) 283} 284 285func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 286 m := &encryptedExtensionsMsg{} 287 288 if rand.Intn(10) > 5 { 289 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 290 } 291 if rand.Intn(10) > 5 { 292 m.earlyData = true 293 } 294 295 return reflect.ValueOf(m) 296} 297 298func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 299 m := &certificateMsg{} 300 numCerts := rand.Intn(20) 301 m.certificates = make([][]byte, numCerts) 302 for i := 0; i < numCerts; i++ { 303 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 304 } 305 return reflect.ValueOf(m) 306} 307 308func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 309 m := &certificateRequestMsg{} 310 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 311 for i := 0; i < rand.Intn(100); i++ { 312 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 313 } 314 return reflect.ValueOf(m) 315} 316 317func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 318 m := &certificateVerifyMsg{} 319 m.hasSignatureAlgorithm = true 320 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 321 m.signature = randomBytes(rand.Intn(15)+1, rand) 322 return reflect.ValueOf(m) 323} 324 325func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 326 m := &certificateStatusMsg{} 327 m.response = randomBytes(rand.Intn(10)+1, rand) 328 return reflect.ValueOf(m) 329} 330 331func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 332 m := &clientKeyExchangeMsg{} 333 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 334 return reflect.ValueOf(m) 335} 336 337func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 338 m := &finishedMsg{} 339 m.verifyData = randomBytes(12, rand) 340 return reflect.ValueOf(m) 341} 342 343func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 344 m := &newSessionTicketMsg{} 345 m.ticket = randomBytes(rand.Intn(4), rand) 346 return reflect.ValueOf(m) 347} 348 349var sessionTestCerts []*x509.Certificate 350 351func init() { 352 cert, err := x509.ParseCertificate(testRSACertificate) 353 if err != nil { 354 panic(err) 355 } 356 sessionTestCerts = append(sessionTestCerts, cert) 357 cert, err = x509.ParseCertificate(testRSACertificateIssuer) 358 if err != nil { 359 panic(err) 360 } 361 sessionTestCerts = append(sessionTestCerts, cert) 362} 363 364func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value { 365 s := &SessionState{} 366 isTLS13 := rand.Intn(10) > 5 367 if isTLS13 { 368 s.version = VersionTLS13 369 } else { 370 s.version = uint16(rand.Intn(VersionTLS13)) 371 } 372 s.isClient = rand.Intn(10) > 5 373 s.cipherSuite = uint16(rand.Intn(math.MaxUint16)) 374 s.createdAt = uint64(rand.Int63()) 375 s.secret = randomBytes(rand.Intn(100)+1, rand) 376 for n, i := rand.Intn(3), 0; i < n; i++ { 377 s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand)) 378 } 379 if rand.Intn(10) > 5 { 380 s.EarlyData = true 381 } 382 if rand.Intn(10) > 5 { 383 s.extMasterSecret = true 384 } 385 if s.isClient || rand.Intn(10) > 5 { 386 if rand.Intn(10) > 5 { 387 s.peerCertificates = sessionTestCerts 388 } else { 389 s.peerCertificates = sessionTestCerts[:1] 390 } 391 } 392 if rand.Intn(10) > 5 && s.peerCertificates != nil { 393 s.ocspResponse = randomBytes(rand.Intn(100)+1, rand) 394 } 395 if rand.Intn(10) > 5 && s.peerCertificates != nil { 396 for i := 0; i < rand.Intn(2)+1; i++ { 397 s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand)) 398 } 399 } 400 if len(s.peerCertificates) > 0 { 401 for i := 0; i < rand.Intn(3); i++ { 402 if rand.Intn(10) > 5 { 403 s.verifiedChains = append(s.verifiedChains, s.peerCertificates) 404 } else { 405 s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1]) 406 } 407 } 408 } 409 if rand.Intn(10) > 5 && s.EarlyData { 410 s.alpnProtocol = string(randomBytes(rand.Intn(10), rand)) 411 } 412 if s.isClient { 413 if isTLS13 { 414 s.useBy = uint64(rand.Int63()) 415 s.ageAdd = uint32(rand.Int63() & math.MaxUint32) 416 } 417 } 418 return reflect.ValueOf(s) 419} 420 421func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() } 422func (s *SessionState) unmarshal(b []byte) bool { 423 ss, err := ParseSessionState(b) 424 if err != nil { 425 return false 426 } 427 *s = *ss 428 return true 429} 430 431func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 432 m := &endOfEarlyDataMsg{} 433 return reflect.ValueOf(m) 434} 435 436func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 437 m := &keyUpdateMsg{} 438 m.updateRequested = rand.Intn(10) > 5 439 return reflect.ValueOf(m) 440} 441 442func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 443 m := &newSessionTicketMsgTLS13{} 444 m.lifetime = uint32(rand.Intn(500000)) 445 m.ageAdd = uint32(rand.Intn(500000)) 446 m.nonce = randomBytes(rand.Intn(100), rand) 447 m.label = randomBytes(rand.Intn(1000), rand) 448 if rand.Intn(10) > 5 { 449 m.maxEarlyData = uint32(rand.Intn(500000)) 450 } 451 return reflect.ValueOf(m) 452} 453 454func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 455 m := &certificateRequestMsgTLS13{} 456 if rand.Intn(10) > 5 { 457 m.ocspStapling = true 458 } 459 if rand.Intn(10) > 5 { 460 m.scts = true 461 } 462 if rand.Intn(10) > 5 { 463 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() 464 } 465 if rand.Intn(10) > 5 { 466 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() 467 } 468 if rand.Intn(10) > 5 { 469 m.certificateAuthorities = make([][]byte, 3) 470 for i := 0; i < 3; i++ { 471 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 472 } 473 } 474 return reflect.ValueOf(m) 475} 476 477func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 478 m := &certificateMsgTLS13{} 479 for i := 0; i < rand.Intn(2)+1; i++ { 480 m.certificate.Certificate = append( 481 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 482 } 483 if rand.Intn(10) > 5 { 484 m.ocspStapling = true 485 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 486 } 487 if rand.Intn(10) > 5 { 488 m.scts = true 489 for i := 0; i < rand.Intn(2)+1; i++ { 490 m.certificate.SignedCertificateTimestamps = append( 491 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 492 } 493 } 494 return reflect.ValueOf(m) 495} 496 497func TestRejectEmptySCTList(t *testing.T) { 498 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 499 500 var random [32]byte 501 sct := []byte{0x42, 0x42, 0x42, 0x42} 502 serverHello := &serverHelloMsg{ 503 vers: VersionTLS12, 504 random: random[:], 505 scts: [][]byte{sct}, 506 } 507 serverHelloBytes := mustMarshal(t, serverHello) 508 509 var serverHelloCopy serverHelloMsg 510 if !serverHelloCopy.unmarshal(serverHelloBytes) { 511 t.Fatal("Failed to unmarshal initial message") 512 } 513 514 // Change serverHelloBytes so that the SCT list is empty 515 i := bytes.Index(serverHelloBytes, sct) 516 if i < 0 { 517 t.Fatal("Cannot find SCT in ServerHello") 518 } 519 520 var serverHelloEmptySCT []byte 521 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 522 // Append the extension length and SCT list length for an empty list. 523 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 524 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 525 526 // Update the handshake message length. 527 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 528 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 529 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 530 531 // Update the extensions length 532 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 533 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 534 535 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 536 t.Fatal("Unmarshaled ServerHello with empty SCT list") 537 } 538} 539 540func TestRejectEmptySCT(t *testing.T) { 541 // Not only must the SCT list be non-empty, but the SCT elements must 542 // not be zero length. 543 544 var random [32]byte 545 serverHello := &serverHelloMsg{ 546 vers: VersionTLS12, 547 random: random[:], 548 scts: [][]byte{nil}, 549 } 550 serverHelloBytes := mustMarshal(t, serverHello) 551 552 var serverHelloCopy serverHelloMsg 553 if serverHelloCopy.unmarshal(serverHelloBytes) { 554 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 555 } 556} 557 558func TestRejectDuplicateExtensions(t *testing.T) { 559 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f") 560 if err != nil { 561 t.Fatalf("failed to decode test ClientHello: %s", err) 562 } 563 var clientHelloCopy clientHelloMsg 564 if clientHelloCopy.unmarshal(clientHelloBytes) { 565 t.Error("Unmarshaled ClientHello with duplicate extensions") 566 } 567 568 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000") 569 if err != nil { 570 t.Fatalf("failed to decode test ServerHello: %s", err) 571 } 572 var serverHelloCopy serverHelloMsg 573 if serverHelloCopy.unmarshal(serverHelloBytes) { 574 t.Fatal("Unmarshaled ServerHello with duplicate extensions") 575 } 576} 577