1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package rand_test 6 7import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "internal/testenv" 12 "io" 13 "math" 14 . "math/rand" 15 "os" 16 "runtime" 17 "strings" 18 "sync" 19 "testing" 20 "testing/iotest" 21) 22 23const ( 24 numTestSamples = 10000 25) 26 27var rn, kn, wn, fn = GetNormalDistributionParameters() 28var re, ke, we, fe = GetExponentialDistributionParameters() 29 30type statsResults struct { 31 mean float64 32 stddev float64 33 closeEnough float64 34 maxError float64 35} 36 37func nearEqual(a, b, closeEnough, maxError float64) bool { 38 absDiff := math.Abs(a - b) 39 if absDiff < closeEnough { // Necessary when one value is zero and one value is close to zero. 40 return true 41 } 42 return absDiff/max(math.Abs(a), math.Abs(b)) < maxError 43} 44 45var testSeeds = []int64{1, 1754801282, 1698661970, 1550503961} 46 47// checkSimilarDistribution returns success if the mean and stddev of the 48// two statsResults are similar. 
49func (sr *statsResults) checkSimilarDistribution(expected *statsResults) error { 50 if !nearEqual(sr.mean, expected.mean, expected.closeEnough, expected.maxError) { 51 s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", sr.mean, expected.mean, expected.closeEnough, expected.maxError) 52 fmt.Println(s) 53 return errors.New(s) 54 } 55 if !nearEqual(sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) { 56 s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) 57 fmt.Println(s) 58 return errors.New(s) 59 } 60 return nil 61} 62 63func getStatsResults(samples []float64) *statsResults { 64 res := new(statsResults) 65 var sum, squaresum float64 66 for _, s := range samples { 67 sum += s 68 squaresum += s * s 69 } 70 res.mean = sum / float64(len(samples)) 71 res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean) 72 return res 73} 74 75func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) { 76 t.Helper() 77 actual := getStatsResults(samples) 78 err := actual.checkSimilarDistribution(expected) 79 if err != nil { 80 t.Error(err) 81 } 82} 83 84func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) { 85 t.Helper() 86 chunk := len(samples) / nslices 87 for i := 0; i < nslices; i++ { 88 low := i * chunk 89 var high int 90 if i == nslices-1 { 91 high = len(samples) - 1 92 } else { 93 high = (i + 1) * chunk 94 } 95 checkSampleDistribution(t, samples[low:high], expected) 96 } 97} 98 99// 100// Normal distribution tests 101// 102 103func generateNormalSamples(nsamples int, mean, stddev float64, seed int64) []float64 { 104 r := New(NewSource(seed)) 105 samples := make([]float64, nsamples) 106 for i := range samples { 107 samples[i] = r.NormFloat64()*stddev + mean 108 } 109 return samples 110} 111 112func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed int64) { 
113 //fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed); 114 115 samples := generateNormalSamples(nsamples, mean, stddev, seed) 116 errorScale := max(1.0, stddev) // Error scales with stddev 117 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 118 119 // Make sure that the entire set matches the expected distribution. 120 checkSampleDistribution(t, samples, expected) 121 122 // Make sure that each half of the set matches the expected distribution. 123 checkSampleSliceDistributions(t, samples, 2, expected) 124 125 // Make sure that each 7th of the set matches the expected distribution. 126 checkSampleSliceDistributions(t, samples, 7, expected) 127} 128 129// Actual tests 130 131func TestStandardNormalValues(t *testing.T) { 132 for _, seed := range testSeeds { 133 testNormalDistribution(t, numTestSamples, 0, 1, seed) 134 } 135} 136 137func TestNonStandardNormalValues(t *testing.T) { 138 sdmax := 1000.0 139 mmax := 1000.0 140 if testing.Short() { 141 sdmax = 5 142 mmax = 5 143 } 144 for sd := 0.5; sd < sdmax; sd *= 2 { 145 for m := 0.5; m < mmax; m *= 2 { 146 for _, seed := range testSeeds { 147 testNormalDistribution(t, numTestSamples, m, sd, seed) 148 if testing.Short() { 149 break 150 } 151 } 152 } 153 } 154} 155 156// 157// Exponential distribution tests 158// 159 160func generateExponentialSamples(nsamples int, rate float64, seed int64) []float64 { 161 r := New(NewSource(seed)) 162 samples := make([]float64, nsamples) 163 for i := range samples { 164 samples[i] = r.ExpFloat64() / rate 165 } 166 return samples 167} 168 169func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed int64) { 170 //fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed); 171 172 mean := 1 / rate 173 stddev := mean 174 175 samples := generateExponentialSamples(nsamples, rate, seed) 176 errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate 177 expected := 
&statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale} 178 179 // Make sure that the entire set matches the expected distribution. 180 checkSampleDistribution(t, samples, expected) 181 182 // Make sure that each half of the set matches the expected distribution. 183 checkSampleSliceDistributions(t, samples, 2, expected) 184 185 // Make sure that each 7th of the set matches the expected distribution. 186 checkSampleSliceDistributions(t, samples, 7, expected) 187} 188 189// Actual tests 190 191func TestStandardExponentialValues(t *testing.T) { 192 for _, seed := range testSeeds { 193 testExponentialDistribution(t, numTestSamples, 1, seed) 194 } 195} 196 197func TestNonStandardExponentialValues(t *testing.T) { 198 for rate := 0.05; rate < 10; rate *= 2 { 199 for _, seed := range testSeeds { 200 testExponentialDistribution(t, numTestSamples, rate, seed) 201 if testing.Short() { 202 break 203 } 204 } 205 } 206} 207 208// 209// Table generation tests 210// 211 212func initNorm() (testKn []uint32, testWn, testFn []float32) { 213 const m1 = 1 << 31 214 var ( 215 dn float64 = rn 216 tn = dn 217 vn float64 = 9.91256303526217e-3 218 ) 219 220 testKn = make([]uint32, 128) 221 testWn = make([]float32, 128) 222 testFn = make([]float32, 128) 223 224 q := vn / math.Exp(-0.5*dn*dn) 225 testKn[0] = uint32((dn / q) * m1) 226 testKn[1] = 0 227 testWn[0] = float32(q / m1) 228 testWn[127] = float32(dn / m1) 229 testFn[0] = 1.0 230 testFn[127] = float32(math.Exp(-0.5 * dn * dn)) 231 for i := 126; i >= 1; i-- { 232 dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn))) 233 testKn[i+1] = uint32((dn / tn) * m1) 234 tn = dn 235 testFn[i] = float32(math.Exp(-0.5 * dn * dn)) 236 testWn[i] = float32(dn / m1) 237 } 238 return 239} 240 241func initExp() (testKe []uint32, testWe, testFe []float32) { 242 const m2 = 1 << 32 243 var ( 244 de float64 = re 245 te = de 246 ve float64 = 3.9496598225815571993e-3 247 ) 248 249 testKe = make([]uint32, 256) 250 testWe = make([]float32, 256) 
251 testFe = make([]float32, 256) 252 253 q := ve / math.Exp(-de) 254 testKe[0] = uint32((de / q) * m2) 255 testKe[1] = 0 256 testWe[0] = float32(q / m2) 257 testWe[255] = float32(de / m2) 258 testFe[0] = 1.0 259 testFe[255] = float32(math.Exp(-de)) 260 for i := 254; i >= 1; i-- { 261 de = -math.Log(ve/de + math.Exp(-de)) 262 testKe[i+1] = uint32((de / te) * m2) 263 te = de 264 testFe[i] = float32(math.Exp(-de)) 265 testWe[i] = float32(de / m2) 266 } 267 return 268} 269 270// compareUint32Slices returns the first index where the two slices 271// disagree, or <0 if the lengths are the same and all elements 272// are identical. 273func compareUint32Slices(s1, s2 []uint32) int { 274 if len(s1) != len(s2) { 275 if len(s1) > len(s2) { 276 return len(s2) + 1 277 } 278 return len(s1) + 1 279 } 280 for i := range s1 { 281 if s1[i] != s2[i] { 282 return i 283 } 284 } 285 return -1 286} 287 288// compareFloat32Slices returns the first index where the two slices 289// disagree, or <0 if the lengths are the same and all elements 290// are identical. 
291func compareFloat32Slices(s1, s2 []float32) int { 292 if len(s1) != len(s2) { 293 if len(s1) > len(s2) { 294 return len(s2) + 1 295 } 296 return len(s1) + 1 297 } 298 for i := range s1 { 299 if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) { 300 return i 301 } 302 } 303 return -1 304} 305 306func TestNormTables(t *testing.T) { 307 testKn, testWn, testFn := initNorm() 308 if i := compareUint32Slices(kn[0:], testKn); i >= 0 { 309 t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i]) 310 } 311 if i := compareFloat32Slices(wn[0:], testWn); i >= 0 { 312 t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i]) 313 } 314 if i := compareFloat32Slices(fn[0:], testFn); i >= 0 { 315 t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i]) 316 } 317} 318 319func TestExpTables(t *testing.T) { 320 testKe, testWe, testFe := initExp() 321 if i := compareUint32Slices(ke[0:], testKe); i >= 0 { 322 t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i]) 323 } 324 if i := compareFloat32Slices(we[0:], testWe); i >= 0 { 325 t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i]) 326 } 327 if i := compareFloat32Slices(fe[0:], testFe); i >= 0 { 328 t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i]) 329 } 330} 331 332func hasSlowFloatingPoint() bool { 333 switch runtime.GOARCH { 334 case "arm": 335 return os.Getenv("GOARM") == "5" || strings.HasSuffix(os.Getenv("GOARM"), ",softfloat") 336 case "mips", "mipsle", "mips64", "mips64le": 337 // Be conservative and assume that all mips boards 338 // have emulated floating point. 339 // TODO: detect what it actually has. 340 return true 341 } 342 return false 343} 344 345func TestFloat32(t *testing.T) { 346 // For issue 6721, the problem came after 7533753 calls, so check 10e6. 347 num := int(10e6) 348 // But do the full amount only on builders (not locally). 
349 // But ARM5 floating point emulation is slow (Issue 10749), so 350 // do less for that builder: 351 if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) { 352 num /= 100 // 1.72 seconds instead of 172 seconds 353 } 354 355 r := New(NewSource(1)) 356 for ct := 0; ct < num; ct++ { 357 f := r.Float32() 358 if f >= 1 { 359 t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f) 360 } 361 } 362} 363 364func testReadUniformity(t *testing.T, n int, seed int64) { 365 r := New(NewSource(seed)) 366 buf := make([]byte, n) 367 nRead, err := r.Read(buf) 368 if err != nil { 369 t.Errorf("Read err %v", err) 370 } 371 if nRead != n { 372 t.Errorf("Read returned unexpected n; %d != %d", nRead, n) 373 } 374 375 // Expect a uniform distribution of byte values, which lie in [0, 255]. 376 var ( 377 mean = 255.0 / 2 378 stddev = 256.0 / math.Sqrt(12.0) 379 errorScale = stddev / math.Sqrt(float64(n)) 380 ) 381 382 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 383 384 // Cast bytes as floats to use the common distribution-validity checks. 385 samples := make([]float64, n) 386 for i, val := range buf { 387 samples[i] = float64(val) 388 } 389 // Make sure that the entire set matches the expected distribution. 
390 checkSampleDistribution(t, samples, expected) 391} 392 393func TestReadUniformity(t *testing.T) { 394 testBufferSizes := []int{ 395 2, 4, 7, 64, 1024, 1 << 16, 1 << 20, 396 } 397 for _, seed := range testSeeds { 398 for _, n := range testBufferSizes { 399 testReadUniformity(t, n, seed) 400 } 401 } 402} 403 404func TestReadEmpty(t *testing.T) { 405 r := New(NewSource(1)) 406 buf := make([]byte, 0) 407 n, err := r.Read(buf) 408 if err != nil { 409 t.Errorf("Read err into empty buffer; %v", err) 410 } 411 if n != 0 { 412 t.Errorf("Read into empty buffer returned unexpected n of %d", n) 413 } 414} 415 416func TestReadByOneByte(t *testing.T) { 417 r := New(NewSource(1)) 418 b1 := make([]byte, 100) 419 _, err := io.ReadFull(iotest.OneByteReader(r), b1) 420 if err != nil { 421 t.Errorf("read by one byte: %v", err) 422 } 423 r = New(NewSource(1)) 424 b2 := make([]byte, 100) 425 _, err = r.Read(b2) 426 if err != nil { 427 t.Errorf("read: %v", err) 428 } 429 if !bytes.Equal(b1, b2) { 430 t.Errorf("read by one byte vs single read:\n%x\n%x", b1, b2) 431 } 432} 433 434func TestReadSeedReset(t *testing.T) { 435 r := New(NewSource(42)) 436 b1 := make([]byte, 128) 437 _, err := r.Read(b1) 438 if err != nil { 439 t.Errorf("read: %v", err) 440 } 441 r.Seed(42) 442 b2 := make([]byte, 128) 443 _, err = r.Read(b2) 444 if err != nil { 445 t.Errorf("read: %v", err) 446 } 447 if !bytes.Equal(b1, b2) { 448 t.Errorf("mismatch after re-seed:\n%x\n%x", b1, b2) 449 } 450} 451 452func TestShuffleSmall(t *testing.T) { 453 // Check that Shuffle allows n=0 and n=1, but that swap is never called for them. 454 r := New(NewSource(1)) 455 for n := 0; n <= 1; n++ { 456 r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) }) 457 } 458} 459 460// encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!). 461// See https://en.wikipedia.org/wiki/Lehmer_code. 462// encodePerm modifies the input slice. 
463func encodePerm(s []int) int { 464 // Convert to Lehmer code. 465 for i, x := range s { 466 r := s[i+1:] 467 for j, y := range r { 468 if y > x { 469 r[j]-- 470 } 471 } 472 } 473 // Convert to int in [0, n!). 474 m := 0 475 fact := 1 476 for i := len(s) - 1; i >= 0; i-- { 477 m += s[i] * fact 478 fact *= len(s) - i 479 } 480 return m 481} 482 483// TestUniformFactorial tests several ways of generating a uniform value in [0, n!). 484func TestUniformFactorial(t *testing.T) { 485 r := New(NewSource(testSeeds[0])) 486 top := 6 487 if testing.Short() { 488 top = 3 489 } 490 for n := 3; n <= top; n++ { 491 t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { 492 // Calculate n!. 493 nfact := 1 494 for i := 2; i <= n; i++ { 495 nfact *= i 496 } 497 498 // Test a few different ways to generate a uniform distribution. 499 p := make([]int, n) // re-usable slice for Shuffle generator 500 tests := [...]struct { 501 name string 502 fn func() int 503 }{ 504 {name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }}, 505 {name: "int31n", fn: func() int { return int(Int31nForTest(r, int32(nfact))) }}, 506 {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }}, 507 {name: "Shuffle", fn: func() int { 508 // Generate permutation using Shuffle. 509 for i := range p { 510 p[i] = i 511 } 512 r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] }) 513 return encodePerm(p) 514 }}, 515 } 516 517 for _, test := range tests { 518 t.Run(test.name, func(t *testing.T) { 519 // Gather chi-squared values and check that they follow 520 // the expected normal distribution given n!-1 degrees of freedom. 521 // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and 522 // https://www.johndcook.com/Beautiful_Testing_ch10.pdf. 523 nsamples := 10 * nfact 524 if nsamples < 200 { 525 nsamples = 200 526 } 527 samples := make([]float64, nsamples) 528 for i := range samples { 529 // Generate some uniformly distributed values and count their occurrences. 
530 const iters = 1000 531 counts := make([]int, nfact) 532 for i := 0; i < iters; i++ { 533 counts[test.fn()]++ 534 } 535 // Calculate chi-squared and add to samples. 536 want := iters / float64(nfact) 537 var χ2 float64 538 for _, have := range counts { 539 err := float64(have) - want 540 χ2 += err * err 541 } 542 χ2 /= want 543 samples[i] = χ2 544 } 545 546 // Check that our samples approximate the appropriate normal distribution. 547 dof := float64(nfact - 1) 548 expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)} 549 errorScale := max(1.0, expected.stddev) 550 expected.closeEnough = 0.10 * errorScale 551 expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211. 552 checkSampleDistribution(t, samples, expected) 553 }) 554 } 555 }) 556 } 557} 558 559// Benchmarks 560 561func BenchmarkInt63Threadsafe(b *testing.B) { 562 for n := b.N; n > 0; n-- { 563 Int63() 564 } 565} 566 567func BenchmarkInt63ThreadsafeParallel(b *testing.B) { 568 b.RunParallel(func(pb *testing.PB) { 569 for pb.Next() { 570 Int63() 571 } 572 }) 573} 574 575func BenchmarkInt63Unthreadsafe(b *testing.B) { 576 r := New(NewSource(1)) 577 for n := b.N; n > 0; n-- { 578 r.Int63() 579 } 580} 581 582func BenchmarkIntn1000(b *testing.B) { 583 r := New(NewSource(1)) 584 for n := b.N; n > 0; n-- { 585 r.Intn(1000) 586 } 587} 588 589func BenchmarkInt63n1000(b *testing.B) { 590 r := New(NewSource(1)) 591 for n := b.N; n > 0; n-- { 592 r.Int63n(1000) 593 } 594} 595 596func BenchmarkInt31n1000(b *testing.B) { 597 r := New(NewSource(1)) 598 for n := b.N; n > 0; n-- { 599 r.Int31n(1000) 600 } 601} 602 603func BenchmarkFloat32(b *testing.B) { 604 r := New(NewSource(1)) 605 for n := b.N; n > 0; n-- { 606 r.Float32() 607 } 608} 609 610func BenchmarkFloat64(b *testing.B) { 611 r := New(NewSource(1)) 612 for n := b.N; n > 0; n-- { 613 r.Float64() 614 } 615} 616 617func BenchmarkPerm3(b *testing.B) { 618 r := New(NewSource(1)) 619 for n := b.N; n > 0; n-- { 620 r.Perm(3) 621 } 622} 

// BenchmarkPerm30 measures Perm for a 30-element permutation.
func BenchmarkPerm30(b *testing.B) {
	r := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		r.Perm(30)
	}
}

// BenchmarkPerm30ViaShuffle builds a 30-element permutation with Shuffle,
// allocating a fresh identity slice on every iteration, for comparison with
// BenchmarkPerm30.
func BenchmarkPerm30ViaShuffle(b *testing.B) {
	r := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		perm := make([]int, 30)
		for k := range perm {
			perm[k] = k
		}
		r.Shuffle(30, func(x, y int) { perm[x], perm[y] = perm[y], perm[x] })
	}
}

// BenchmarkShuffleOverhead uses a minimal swap function
// to measure just the shuffling overhead.
func BenchmarkShuffleOverhead(b *testing.B) {
	const cards = 52
	r := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		r.Shuffle(cards, func(x, y int) {
			inRange := 0 <= x && x < cards && 0 <= y && y < cards
			if !inRange {
				b.Fatalf("bad swap(%d, %d)", x, y)
			}
		})
	}
}

// BenchmarkRead3 measures Read into a 3-byte buffer.
func BenchmarkRead3(b *testing.B) {
	r := New(NewSource(1))
	buf := make([]byte, 3)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		r.Read(buf)
	}
}

// BenchmarkRead64 measures Read into a 64-byte buffer.
func BenchmarkRead64(b *testing.B) {
	r := New(NewSource(1))
	buf := make([]byte, 64)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		r.Read(buf)
	}
}

// BenchmarkRead1000 measures Read into a 1000-byte buffer.
func BenchmarkRead1000(b *testing.B) {
	r := New(NewSource(1))
	buf := make([]byte, 1000)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		r.Read(buf)
	}
}

// BenchmarkConcurrent has four goroutines each call the package-level Int63
// b.N times at the same time, and waits for all of them to finish.
func BenchmarkConcurrent(b *testing.B) {
	const goroutines = 4
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for g := 0; g < goroutines; g++ {
		go func() {
			defer wg.Done()
			for i := 0; i < b.N; i++ {
				Int63()
			}
		}()
	}
	wg.Wait()
}