1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package rand
6
7import (
8	"errors"
9	"internal/byteorder"
10	"internal/chacha8rand"
11)
12
// A ChaCha8 is a ChaCha8-based cryptographically strong
// random number generator.
type ChaCha8 struct {
	// state is the underlying chacha8rand generator. Uint64 draws values
	// from it with Next and calls Refill whenever Next returns ok == false.
	state chacha8rand.State

	// readBuf holds the little-endian encoding of a Uint64 value that Read
	// only partially handed out (it is also restored by UnmarshalBinary).
	// The last readLen bytes of readBuf are still to be consumed by Read.
	readBuf [8]byte
	readLen int // 0 <= readLen <= 8
}
22
23// NewChaCha8 returns a new ChaCha8 seeded with the given seed.
24func NewChaCha8(seed [32]byte) *ChaCha8 {
25	c := new(ChaCha8)
26	c.state.Init(seed)
27	return c
28}
29
30// Seed resets the ChaCha8 to behave the same way as NewChaCha8(seed).
31func (c *ChaCha8) Seed(seed [32]byte) {
32	c.state.Init(seed)
33	c.readLen = 0
34	c.readBuf = [8]byte{}
35}
36
37// Uint64 returns a uniformly distributed random uint64 value.
38func (c *ChaCha8) Uint64() uint64 {
39	for {
40		x, ok := c.state.Next()
41		if ok {
42			return x
43		}
44		c.state.Refill()
45	}
46}
47
48// Read reads exactly len(p) bytes into p.
49// It always returns len(p) and a nil error.
50//
51// If calls to Read and Uint64 are interleaved, the order in which bits are
52// returned by the two is undefined, and Read may return bits generated before
53// the last call to Uint64.
54func (c *ChaCha8) Read(p []byte) (n int, err error) {
55	if c.readLen > 0 {
56		n = copy(p, c.readBuf[len(c.readBuf)-c.readLen:])
57		c.readLen -= n
58		p = p[n:]
59	}
60	for len(p) >= 8 {
61		byteorder.LePutUint64(p, c.Uint64())
62		p = p[8:]
63		n += 8
64	}
65	if len(p) > 0 {
66		byteorder.LePutUint64(c.readBuf[:], c.Uint64())
67		n += copy(p, c.readBuf[:])
68		c.readLen = 8 - len(p)
69	}
70	return
71}
72
73// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
74func (c *ChaCha8) UnmarshalBinary(data []byte) error {
75	data, ok := cutPrefix(data, []byte("readbuf:"))
76	if ok {
77		var buf []byte
78		buf, data, ok = readUint8LengthPrefixed(data)
79		if !ok {
80			return errors.New("invalid ChaCha8 Read buffer encoding")
81		}
82		c.readLen = copy(c.readBuf[len(c.readBuf)-len(buf):], buf)
83	}
84	return chacha8rand.Unmarshal(&c.state, data)
85}
86
// cutPrefix reports whether s begins with prefix. If it does, it returns s
// with the prefix removed and found == true. Otherwise it returns s
// unchanged and found == false.
func cutPrefix(s, prefix []byte) (after []byte, found bool) {
	n := len(prefix)
	if n <= len(s) && string(s[:n]) == string(prefix) {
		return s[n:], true
	}
	return s, false
}
93
// readUint8LengthPrefixed splits off a length-prefixed field from b.
// b[0] gives the field length (0-255). buf is the b[0] bytes that follow
// it, and rest is everything after them. ok is false, and buf and rest are
// nil, when b is empty or too short to hold the advertised field.
func readUint8LengthPrefixed(b []byte) (buf, rest []byte, ok bool) {
	if len(b) == 0 {
		return nil, nil, false
	}
	// Widen to int before adding. In byte arithmetic 1+b[0] wraps to 0 when
	// b[0] == 255, which let the length check pass and the slicing panic.
	end := 1 + int(b[0])
	if len(b) < end {
		return nil, nil, false
	}
	return b[1:end], b[end:], true
}
100
101// MarshalBinary implements the encoding.BinaryMarshaler interface.
102func (c *ChaCha8) MarshalBinary() ([]byte, error) {
103	if c.readLen > 0 {
104		out := []byte("readbuf:")
105		out = append(out, uint8(c.readLen))
106		out = append(out, c.readBuf[len(c.readBuf)-c.readLen:]...)
107		return append(out, chacha8rand.Marshal(&c.state)...), nil
108	}
109	return chacha8rand.Marshal(&c.state), nil
110}
111