1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package rand_test
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"internal/testenv"
12	"io"
13	"math"
14	. "math/rand"
15	"os"
16	"runtime"
17	"strings"
18	"sync"
19	"testing"
20	"testing/iotest"
21)
22
23const (
24	numTestSamples = 10000
25)
26
27var rn, kn, wn, fn = GetNormalDistributionParameters()
28var re, ke, we, fe = GetExponentialDistributionParameters()
29
// statsResults holds the summary statistics of a sample set together with
// the tolerances used when comparing it against another set.
type statsResults struct {
	mean, stddev          float64 // first two moments of the samples
	closeEnough, maxError float64 // absolute and relative tolerances for nearEqual
}
36
// nearEqual reports whether a and b are equal within tolerance. They match
// if their absolute difference is below closeEnough, or if that difference
// relative to the larger magnitude is below maxError. Identical values
// always match, including a == b == 0 with closeEnough == 0, where the
// relative test alone would compute 0/0 = NaN and wrongly report false.
func nearEqual(a, b, closeEnough, maxError float64) bool {
	absDiff := math.Abs(a - b)
	// The absolute check is necessary when one value is zero and the other
	// is close to zero; the equality check covers two exact zeros.
	if absDiff == 0 || absDiff < closeEnough {
		return true
	}
	return absDiff/max(math.Abs(a), math.Abs(b)) < maxError
}
44
// testSeeds are the fixed PRNG seeds used by the distribution tests, so each
// run examines the same sample streams.
var testSeeds = []int64{
	1,
	1754801282,
	1698661970,
	1550503961,
}
46
47// checkSimilarDistribution returns success if the mean and stddev of the
48// two statsResults are similar.
49func (sr *statsResults) checkSimilarDistribution(expected *statsResults) error {
50	if !nearEqual(sr.mean, expected.mean, expected.closeEnough, expected.maxError) {
51		s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", sr.mean, expected.mean, expected.closeEnough, expected.maxError)
52		fmt.Println(s)
53		return errors.New(s)
54	}
55	if !nearEqual(sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) {
56		s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", sr.stddev, expected.stddev, expected.closeEnough, expected.maxError)
57		fmt.Println(s)
58		return errors.New(s)
59	}
60	return nil
61}
62
63func getStatsResults(samples []float64) *statsResults {
64	res := new(statsResults)
65	var sum, squaresum float64
66	for _, s := range samples {
67		sum += s
68		squaresum += s * s
69	}
70	res.mean = sum / float64(len(samples))
71	res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean)
72	return res
73}
74
75func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) {
76	t.Helper()
77	actual := getStatsResults(samples)
78	err := actual.checkSimilarDistribution(expected)
79	if err != nil {
80		t.Error(err)
81	}
82}
83
84func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) {
85	t.Helper()
86	chunk := len(samples) / nslices
87	for i := 0; i < nslices; i++ {
88		low := i * chunk
89		var high int
90		if i == nslices-1 {
91			high = len(samples) - 1
92		} else {
93			high = (i + 1) * chunk
94		}
95		checkSampleDistribution(t, samples[low:high], expected)
96	}
97}
98
99//
100// Normal distribution tests
101//
102
// generateNormalSamples draws nsamples normal variates with the given mean
// and standard deviation from a generator seeded with seed.
func generateNormalSamples(nsamples int, mean, stddev float64, seed int64) []float64 {
	rng := New(NewSource(seed))
	out := make([]float64, 0, nsamples)
	for len(out) < nsamples {
		out = append(out, rng.NormFloat64()*stddev+mean)
	}
	return out
}
111
112func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed int64) {
113	//fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed);
114
115	samples := generateNormalSamples(nsamples, mean, stddev, seed)
116	errorScale := max(1.0, stddev) // Error scales with stddev
117	expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale}
118
119	// Make sure that the entire set matches the expected distribution.
120	checkSampleDistribution(t, samples, expected)
121
122	// Make sure that each half of the set matches the expected distribution.
123	checkSampleSliceDistributions(t, samples, 2, expected)
124
125	// Make sure that each 7th of the set matches the expected distribution.
126	checkSampleSliceDistributions(t, samples, 7, expected)
127}
128
129// Actual tests
130
131func TestStandardNormalValues(t *testing.T) {
132	for _, seed := range testSeeds {
133		testNormalDistribution(t, numTestSamples, 0, 1, seed)
134	}
135}
136
137func TestNonStandardNormalValues(t *testing.T) {
138	sdmax := 1000.0
139	mmax := 1000.0
140	if testing.Short() {
141		sdmax = 5
142		mmax = 5
143	}
144	for sd := 0.5; sd < sdmax; sd *= 2 {
145		for m := 0.5; m < mmax; m *= 2 {
146			for _, seed := range testSeeds {
147				testNormalDistribution(t, numTestSamples, m, sd, seed)
148				if testing.Short() {
149					break
150				}
151			}
152		}
153	}
154}
155
156//
157// Exponential distribution tests
158//
159
// generateExponentialSamples draws nsamples exponential variates with the
// given rate from a generator seeded with seed. Each draw is ExpFloat64
// scaled by 1/rate.
func generateExponentialSamples(nsamples int, rate float64, seed int64) []float64 {
	rng := New(NewSource(seed))
	out := make([]float64, 0, nsamples)
	for len(out) < nsamples {
		out = append(out, rng.ExpFloat64()/rate)
	}
	return out
}
168
169func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed int64) {
170	//fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed);
171
172	mean := 1 / rate
173	stddev := mean
174
175	samples := generateExponentialSamples(nsamples, rate, seed)
176	errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate
177	expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale}
178
179	// Make sure that the entire set matches the expected distribution.
180	checkSampleDistribution(t, samples, expected)
181
182	// Make sure that each half of the set matches the expected distribution.
183	checkSampleSliceDistributions(t, samples, 2, expected)
184
185	// Make sure that each 7th of the set matches the expected distribution.
186	checkSampleSliceDistributions(t, samples, 7, expected)
187}
188
189// Actual tests
190
191func TestStandardExponentialValues(t *testing.T) {
192	for _, seed := range testSeeds {
193		testExponentialDistribution(t, numTestSamples, 1, seed)
194	}
195}
196
197func TestNonStandardExponentialValues(t *testing.T) {
198	for rate := 0.05; rate < 10; rate *= 2 {
199		for _, seed := range testSeeds {
200			testExponentialDistribution(t, numTestSamples, rate, seed)
201			if testing.Short() {
202				break
203			}
204		}
205	}
206}
207
208//
209// Table generation tests
210//
211
// initNorm independently recomputes the normal-distribution tables so that
// TestNormTables can compare them with the package's kn, wn and fn. The
// recurrence appears to be the Marsaglia–Tsang ziggurat table setup
// (NOTE(review): confirm against the generator's own table code).
//
// All three returned slices have 128 entries:
//   - testKn[0] is (rn/q)·2^31 with q = vn/exp(-rn²/2), and testKn[1] is 0.
//     Every later entry is the ratio of successive x values scaled by 2^31
//     and truncated to uint32.
//   - testWn holds q/2^31 at index 0 and x/2^31 at the other indices.
//   - testFn holds exp(-x²/2) for each x, with testFn[0] fixed at 1.
func initNorm() (testKn []uint32, testWn, testFn []float32) {
	// m1 = 2^31 converts between the float values and the fixed-point
	// ratios stored in testKn.
	const m1 = 1 << 31
	// dn walks the x values downward starting from rn. tn remembers the
	// previous x so that each ratio dn/tn can be stored. vn is a fixed
	// constant of the construction (presumably the common layer area — TODO
	// confirm).
	var (
		dn float64 = rn
		tn         = dn
		vn float64 = 9.91256303526217e-3
	)

	testKn = make([]uint32, 128)
	testWn = make([]float32, 128)
	testFn = make([]float32, 128)

	// Fill the outermost entries from rn before the recurrence runs.
	q := vn / math.Exp(-0.5*dn*dn)
	testKn[0] = uint32((dn / q) * m1)
	testKn[1] = 0
	testWn[0] = float32(q / m1)
	testWn[127] = float32(dn / m1)
	testFn[0] = 1.0
	testFn[127] = float32(math.Exp(-0.5 * dn * dn))
	// Walk inward. Each step solves exp(-x²/2) = vn/dn + exp(-dn²/2) for
	// the next (smaller) x, then fills index i of testWn/testFn and index
	// i+1 of testKn.
	for i := 126; i >= 1; i-- {
		dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn)))
		testKn[i+1] = uint32((dn / tn) * m1)
		tn = dn
		testFn[i] = float32(math.Exp(-0.5 * dn * dn))
		testWn[i] = float32(dn / m1)
	}
	return
}
240
// initExp independently recomputes the exponential-distribution tables so
// that TestExpTables can compare them with the package's ke, we and fe. The
// recurrence appears to be the Marsaglia–Tsang ziggurat table setup
// (NOTE(review): confirm against the generator's own table code).
//
// All three returned slices have 256 entries:
//   - testKe[0] is (re/q)·2^32 with q = ve/exp(-re), and testKe[1] is 0.
//     Every later entry is the ratio of successive x values scaled by 2^32
//     and truncated to uint32.
//   - testWe holds q/2^32 at index 0 and x/2^32 at the other indices.
//   - testFe holds exp(-x) for each x, with testFe[0] fixed at 1.
func initExp() (testKe []uint32, testWe, testFe []float32) {
	// m2 = 2^32 converts between the float values and the fixed-point
	// ratios stored in testKe.
	const m2 = 1 << 32
	// de walks the x values downward starting from re. te remembers the
	// previous x so that each ratio de/te can be stored. ve is a fixed
	// constant of the construction (presumably the common layer area — TODO
	// confirm).
	var (
		de float64 = re
		te         = de
		ve float64 = 3.9496598225815571993e-3
	)

	testKe = make([]uint32, 256)
	testWe = make([]float32, 256)
	testFe = make([]float32, 256)

	// Fill the outermost entries from re before the recurrence runs.
	q := ve / math.Exp(-de)
	testKe[0] = uint32((de / q) * m2)
	testKe[1] = 0
	testWe[0] = float32(q / m2)
	testWe[255] = float32(de / m2)
	testFe[0] = 1.0
	testFe[255] = float32(math.Exp(-de))
	// Walk inward. Each step solves exp(-x) = ve/de + exp(-de) for the next
	// (smaller) x, then fills index i of testWe/testFe and index i+1 of
	// testKe.
	for i := 254; i >= 1; i-- {
		de = -math.Log(ve/de + math.Exp(-de))
		testKe[i+1] = uint32((de / te) * m2)
		te = de
		testFe[i] = float32(math.Exp(-de))
		testWe[i] = float32(de / m2)
	}
	return
}
269
// compareUint32Slices returns the first index at which s1 and s2 disagree,
// or -1 if they have the same length and identical elements.
//
// The common prefix is compared first. If one slice is a proper prefix of
// the other, the result is the length of the shorter slice, which is the
// first index present in only one of them. That index is out of range for
// the shorter slice, so callers must check lengths before indexing with it.
func compareUint32Slices(s1, s2 []uint32) int {
	n := min(len(s1), len(s2))
	for i := 0; i < n; i++ {
		if s1[i] != s2[i] {
			return i
		}
	}
	if len(s1) != len(s2) {
		return n
	}
	return -1
}
287
288// compareFloat32Slices returns the first index where the two slices
289// disagree, or <0 if the lengths are the same and all elements
290// are identical.
291func compareFloat32Slices(s1, s2 []float32) int {
292	if len(s1) != len(s2) {
293		if len(s1) > len(s2) {
294			return len(s2) + 1
295		}
296		return len(s1) + 1
297	}
298	for i := range s1 {
299		if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) {
300			return i
301		}
302	}
303	return -1
304}
305
306func TestNormTables(t *testing.T) {
307	testKn, testWn, testFn := initNorm()
308	if i := compareUint32Slices(kn[0:], testKn); i >= 0 {
309		t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i])
310	}
311	if i := compareFloat32Slices(wn[0:], testWn); i >= 0 {
312		t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i])
313	}
314	if i := compareFloat32Slices(fn[0:], testFn); i >= 0 {
315		t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i])
316	}
317}
318
319func TestExpTables(t *testing.T) {
320	testKe, testWe, testFe := initExp()
321	if i := compareUint32Slices(ke[0:], testKe); i >= 0 {
322		t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i])
323	}
324	if i := compareFloat32Slices(we[0:], testWe); i >= 0 {
325		t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i])
326	}
327	if i := compareFloat32Slices(fe[0:], testFe); i >= 0 {
328		t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i])
329	}
330}
331
// hasSlowFloatingPoint reports whether the target probably emulates floating
// point in software. That covers ARM builds whose GOARM is "5" or ends in
// ",softfloat", and every MIPS variant.
func hasSlowFloatingPoint() bool {
	arch := runtime.GOARCH
	if arch == "arm" {
		goarm := os.Getenv("GOARM")
		return goarm == "5" || strings.HasSuffix(goarm, ",softfloat")
	}
	switch arch {
	case "mips", "mipsle", "mips64", "mips64le":
		// Be conservative and assume that all mips boards
		// have emulated floating point.
		// TODO: detect what it actually has.
		return true
	default:
		return false
	}
}
344
345func TestFloat32(t *testing.T) {
346	// For issue 6721, the problem came after 7533753 calls, so check 10e6.
347	num := int(10e6)
348	// But do the full amount only on builders (not locally).
349	// But ARM5 floating point emulation is slow (Issue 10749), so
350	// do less for that builder:
351	if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) {
352		num /= 100 // 1.72 seconds instead of 172 seconds
353	}
354
355	r := New(NewSource(1))
356	for ct := 0; ct < num; ct++ {
357		f := r.Float32()
358		if f >= 1 {
359			t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f)
360		}
361	}
362}
363
364func testReadUniformity(t *testing.T, n int, seed int64) {
365	r := New(NewSource(seed))
366	buf := make([]byte, n)
367	nRead, err := r.Read(buf)
368	if err != nil {
369		t.Errorf("Read err %v", err)
370	}
371	if nRead != n {
372		t.Errorf("Read returned unexpected n; %d != %d", nRead, n)
373	}
374
375	// Expect a uniform distribution of byte values, which lie in [0, 255].
376	var (
377		mean       = 255.0 / 2
378		stddev     = 256.0 / math.Sqrt(12.0)
379		errorScale = stddev / math.Sqrt(float64(n))
380	)
381
382	expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale}
383
384	// Cast bytes as floats to use the common distribution-validity checks.
385	samples := make([]float64, n)
386	for i, val := range buf {
387		samples[i] = float64(val)
388	}
389	// Make sure that the entire set matches the expected distribution.
390	checkSampleDistribution(t, samples, expected)
391}
392
393func TestReadUniformity(t *testing.T) {
394	testBufferSizes := []int{
395		2, 4, 7, 64, 1024, 1 << 16, 1 << 20,
396	}
397	for _, seed := range testSeeds {
398		for _, n := range testBufferSizes {
399			testReadUniformity(t, n, seed)
400		}
401	}
402}
403
// TestReadEmpty checks that Read into a zero-length buffer succeeds and
// reports zero bytes read.
func TestReadEmpty(t *testing.T) {
	rng := New(NewSource(1))
	got, err := rng.Read([]byte{})
	if err != nil {
		t.Errorf("Read err into empty buffer; %v", err)
	}
	if got != 0 {
		t.Errorf("Read into empty buffer returned unexpected n of %d", got)
	}
}
415
// TestReadByOneByte checks that reading 100 bytes one byte at a time yields
// the same stream as a single 100-byte Read from an identically seeded
// generator.
func TestReadByOneByte(t *testing.T) {
	const size = 100

	oneByOne := make([]byte, size)
	if _, err := io.ReadFull(iotest.OneByteReader(New(NewSource(1))), oneByOne); err != nil {
		t.Errorf("read by one byte: %v", err)
	}

	whole := make([]byte, size)
	if _, err := New(NewSource(1)).Read(whole); err != nil {
		t.Errorf("read: %v", err)
	}

	if !bytes.Equal(oneByOne, whole) {
		t.Errorf("read by one byte vs single read:\n%x\n%x", oneByOne, whole)
	}
}
433
// TestReadSeedReset checks that re-seeding a generator with its original
// seed makes Read reproduce the same 128 bytes.
func TestReadSeedReset(t *testing.T) {
	r := New(NewSource(42))
	readBlock := func() []byte {
		b := make([]byte, 128)
		if _, err := r.Read(b); err != nil {
			t.Errorf("read: %v", err)
		}
		return b
	}

	first := readBlock()
	r.Seed(42)
	second := readBlock()

	if !bytes.Equal(first, second) {
		t.Errorf("mismatch after re-seed:\n%x\n%x", first, second)
	}
}
451
// TestShuffleSmall checks that Shuffle accepts n=0 and n=1 and never calls
// swap for them.
func TestShuffleSmall(t *testing.T) {
	rng := New(NewSource(1))
	for _, n := range []int{0, 1} {
		rng.Shuffle(n, func(i, j int) {
			t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j)
		})
	}
}
459
// encodePerm maps a permutation of 0..len(s)-1, such as Perm generates, to
// its lexicographic rank in [0, len(s)!) via the Lehmer code
// (https://en.wikipedia.org/wiki/Lehmer_code). It overwrites s with the
// Lehmer code along the way.
func encodePerm(s []int) int {
	n := len(s)
	// Rewrite s in place into its Lehmer code. After position i is
	// visited, every later entry larger than s[i] has been decremented.
	for i := 0; i < n; i++ {
		pivot := s[i]
		for j := i + 1; j < n; j++ {
			if s[j] > pivot {
				s[j]--
			}
		}
	}
	// Read the code as a factorial-base number using Horner's rule, where
	// digit i has radix n-i.
	rank := 0
	for i, digit := range s {
		rank = rank*(n-i) + digit
	}
	return rank
}
482
483// TestUniformFactorial tests several ways of generating a uniform value in [0, n!).
484func TestUniformFactorial(t *testing.T) {
485	r := New(NewSource(testSeeds[0]))
486	top := 6
487	if testing.Short() {
488		top = 3
489	}
490	for n := 3; n <= top; n++ {
491		t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) {
492			// Calculate n!.
493			nfact := 1
494			for i := 2; i <= n; i++ {
495				nfact *= i
496			}
497
498			// Test a few different ways to generate a uniform distribution.
499			p := make([]int, n) // re-usable slice for Shuffle generator
500			tests := [...]struct {
501				name string
502				fn   func() int
503			}{
504				{name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }},
505				{name: "int31n", fn: func() int { return int(Int31nForTest(r, int32(nfact))) }},
506				{name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }},
507				{name: "Shuffle", fn: func() int {
508					// Generate permutation using Shuffle.
509					for i := range p {
510						p[i] = i
511					}
512					r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] })
513					return encodePerm(p)
514				}},
515			}
516
517			for _, test := range tests {
518				t.Run(test.name, func(t *testing.T) {
519					// Gather chi-squared values and check that they follow
520					// the expected normal distribution given n!-1 degrees of freedom.
521					// See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and
522					// https://www.johndcook.com/Beautiful_Testing_ch10.pdf.
523					nsamples := 10 * nfact
524					if nsamples < 200 {
525						nsamples = 200
526					}
527					samples := make([]float64, nsamples)
528					for i := range samples {
529						// Generate some uniformly distributed values and count their occurrences.
530						const iters = 1000
531						counts := make([]int, nfact)
532						for i := 0; i < iters; i++ {
533							counts[test.fn()]++
534						}
535						// Calculate chi-squared and add to samples.
536						want := iters / float64(nfact)
537						var χ2 float64
538						for _, have := range counts {
539							err := float64(have) - want
540							χ2 += err * err
541						}
542						χ2 /= want
543						samples[i] = χ2
544					}
545
546					// Check that our samples approximate the appropriate normal distribution.
547					dof := float64(nfact - 1)
548					expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)}
549					errorScale := max(1.0, expected.stddev)
550					expected.closeEnough = 0.10 * errorScale
551					expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211.
552					checkSampleDistribution(t, samples, expected)
553				})
554			}
555		})
556	}
557}
558
559// Benchmarks
560
// BenchmarkInt63Threadsafe measures the top-level Int63 function called
// from a single goroutine.
func BenchmarkInt63Threadsafe(b *testing.B) {
	for i := 0; i < b.N; i++ {
		Int63()
	}
}
566
// BenchmarkInt63ThreadsafeParallel measures the top-level Int63 function
// called from the parallel goroutines started by RunParallel.
func BenchmarkInt63ThreadsafeParallel(b *testing.B) {
	worker := func(pb *testing.PB) {
		for pb.Next() {
			Int63()
		}
	}
	b.RunParallel(worker)
}
574
// BenchmarkInt63Unthreadsafe measures Int63 on a privately owned *Rand.
func BenchmarkInt63Unthreadsafe(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Int63()
	}
}
581
// BenchmarkIntn1000 measures Intn with a bound of 1000.
func BenchmarkIntn1000(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Intn(1000)
	}
}
588
// BenchmarkInt63n1000 measures Int63n with a bound of 1000.
func BenchmarkInt63n1000(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Int63n(1000)
	}
}
595
// BenchmarkInt31n1000 measures Int31n with a bound of 1000.
func BenchmarkInt31n1000(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Int31n(1000)
	}
}
602
// BenchmarkFloat32 measures Float32.
func BenchmarkFloat32(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Float32()
	}
}
609
// BenchmarkFloat64 measures Float64.
func BenchmarkFloat64(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Float64()
	}
}
616
// BenchmarkPerm3 measures Perm for a 3-element permutation.
func BenchmarkPerm3(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Perm(3)
	}
}
623
// BenchmarkPerm30 measures Perm for a 30-element permutation.
func BenchmarkPerm30(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		rng.Perm(30)
	}
}
630
// BenchmarkPerm30ViaShuffle builds a 30-element permutation with Shuffle,
// allocating a fresh identity slice each iteration, as Perm does.
func BenchmarkPerm30ViaShuffle(b *testing.B) {
	rng := New(NewSource(1))
	for i := 0; i < b.N; i++ {
		perm := make([]int, 30)
		for k := range perm {
			perm[k] = k
		}
		rng.Shuffle(len(perm), func(x, y int) { perm[x], perm[y] = perm[y], perm[x] })
	}
}
641
// BenchmarkShuffleOverhead measures only the cost of Shuffle itself. Its
// swap function moves nothing and merely validates that both indices lie
// in [0, 52).
func BenchmarkShuffleOverhead(b *testing.B) {
	const deck = 52
	rng := New(NewSource(1))
	for n := 0; n < b.N; n++ {
		rng.Shuffle(deck, func(i, j int) {
			if !(0 <= i && i < deck && 0 <= j && j < deck) {
				b.Fatalf("bad swap(%d, %d)", i, j)
			}
		})
	}
}
654
// BenchmarkRead3 measures Read into a 3-byte buffer. Setup is excluded from
// the timing.
func BenchmarkRead3(b *testing.B) {
	rng := New(NewSource(1))
	buf := make([]byte, 3)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = rng.Read(buf) // error ignored, as in the other Read benchmarks
	}
}
663
// BenchmarkRead64 measures Read into a 64-byte buffer. Setup is excluded
// from the timing.
func BenchmarkRead64(b *testing.B) {
	rng := New(NewSource(1))
	buf := make([]byte, 64)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = rng.Read(buf) // error ignored, as in the other Read benchmarks
	}
}
672
// BenchmarkRead1000 measures Read into a 1000-byte buffer. Setup is
// excluded from the timing.
func BenchmarkRead1000(b *testing.B) {
	rng := New(NewSource(1))
	buf := make([]byte, 1000)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = rng.Read(buf) // error ignored, as in the other Read benchmarks
	}
}
681
// BenchmarkConcurrent starts four goroutines that each call the top-level
// Int63 b.N times, and waits for all of them. One benchmark operation
// therefore covers four concurrent calls.
func BenchmarkConcurrent(b *testing.B) {
	const workers = 4
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := 0; i < b.N; i++ {
				Int63()
			}
		}()
	}
	wg.Wait()
}
696