1// Copyright (c) 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package edwards25519
6
7import (
8	"bytes"
9	"encoding/hex"
10	"math/big"
11	mathrand "math/rand"
12	"reflect"
13	"testing"
14	"testing/quick"
15)
16
// quickCheckConfig builds the quick.Config used by the property tests in this
// file. In -short mode it is the zero Config (quick's default iteration
// count); otherwise MaxCountScale is set to slowScale so the default count is
// multiplied by that factor.
func quickCheckConfig(slowScale int) *quick.Config {
	if testing.Short() {
		return &quick.Config{}
	}
	return &quick.Config{MaxCountScale: float64(slowScale)}
}
26
27var scOneBytes = [32]byte{1}
28var scOne, _ = new(Scalar).SetCanonicalBytes(scOneBytes[:])
29var scMinusOne, _ = new(Scalar).SetCanonicalBytes(scalarMinusOneBytes[:])
30
31// Generate returns a valid (reduced modulo l) Scalar with a distribution
32// weighted towards high, low, and edge values.
33func (Scalar) Generate(rand *mathrand.Rand, size int) reflect.Value {
34	var s [32]byte
35	diceRoll := rand.Intn(100)
36	switch {
37	case diceRoll == 0:
38	case diceRoll == 1:
39		s = scOneBytes
40	case diceRoll == 2:
41		s = scalarMinusOneBytes
42	case diceRoll < 5:
43		// Generate a low scalar in [0, 2^125).
44		rand.Read(s[:16])
45		s[15] &= (1 << 5) - 1
46	case diceRoll < 10:
47		// Generate a high scalar in [2^252, 2^252 + 2^124).
48		s[31] = 1 << 4
49		rand.Read(s[:16])
50		s[15] &= (1 << 4) - 1
51	default:
52		// Generate a valid scalar in [0, l) by returning [0, 2^252) which has a
53		// negligibly different distribution (the former has a 2^-127.6 chance
54		// of being out of the latter range).
55		rand.Read(s[:])
56		s[31] &= (1 << 4) - 1
57	}
58
59	val := Scalar{}
60	fiatScalarFromBytes((*[4]uint64)(&val.s), &s)
61	fiatScalarToMontgomery(&val.s, (*fiatScalarNonMontgomeryDomainFieldElement)(&val.s))
62
63	return reflect.ValueOf(val)
64}
65
66func TestScalarGenerate(t *testing.T) {
67	f := func(sc Scalar) bool {
68		return isReduced(sc.Bytes())
69	}
70	if err := quick.Check(f, quickCheckConfig(1024)); err != nil {
71		t.Errorf("generated unreduced scalar: %v", err)
72	}
73}
74
75func TestScalarSetCanonicalBytes(t *testing.T) {
76	f1 := func(in [32]byte, sc Scalar) bool {
77		// Mask out top 4 bits to guarantee value falls in [0, l).
78		in[len(in)-1] &= (1 << 4) - 1
79		if _, err := sc.SetCanonicalBytes(in[:]); err != nil {
80			return false
81		}
82		repr := sc.Bytes()
83		return bytes.Equal(in[:], repr) && isReduced(repr)
84	}
85	if err := quick.Check(f1, quickCheckConfig(1024)); err != nil {
86		t.Errorf("failed bytes->scalar->bytes round-trip: %v", err)
87	}
88
89	f2 := func(sc1, sc2 Scalar) bool {
90		if _, err := sc2.SetCanonicalBytes(sc1.Bytes()); err != nil {
91			return false
92		}
93		return sc1 == sc2
94	}
95	if err := quick.Check(f2, quickCheckConfig(1024)); err != nil {
96		t.Errorf("failed scalar->bytes->scalar round-trip: %v", err)
97	}
98
99	b := scalarMinusOneBytes
100	b[31] += 1
101	s := scOne
102	if out, err := s.SetCanonicalBytes(b[:]); err == nil {
103		t.Errorf("SetCanonicalBytes worked on a non-canonical value")
104	} else if s != scOne {
105		t.Errorf("SetCanonicalBytes modified its receiver")
106	} else if out != nil {
107		t.Errorf("SetCanonicalBytes did not return nil with an error")
108	}
109}
110
111func TestScalarSetUniformBytes(t *testing.T) {
112	mod, _ := new(big.Int).SetString("27742317777372353535851937790883648493", 10)
113	mod.Add(mod, new(big.Int).Lsh(big.NewInt(1), 252))
114	f := func(in [64]byte, sc Scalar) bool {
115		sc.SetUniformBytes(in[:])
116		repr := sc.Bytes()
117		if !isReduced(repr) {
118			return false
119		}
120		scBig := bigIntFromLittleEndianBytes(repr[:])
121		inBig := bigIntFromLittleEndianBytes(in[:])
122		return inBig.Mod(inBig, mod).Cmp(scBig) == 0
123	}
124	if err := quick.Check(f, quickCheckConfig(1024)); err != nil {
125		t.Error(err)
126	}
127}
128
129func TestScalarSetBytesWithClamping(t *testing.T) {
130	// Generated with libsodium.js 1.0.18 crypto_scalarmult_ed25519_base.
131
132	random := "633d368491364dc9cd4c1bf891b1d59460face1644813240a313e61f2c88216e"
133	s, _ := new(Scalar).SetBytesWithClamping(decodeHex(random))
134	p := new(Point).ScalarBaseMult(s)
135	want := "1d87a9026fd0126a5736fe1628c95dd419172b5b618457e041c9c861b2494a94"
136	if got := hex.EncodeToString(p.Bytes()); got != want {
137		t.Errorf("random: got %q, want %q", got, want)
138	}
139
140	zero := "0000000000000000000000000000000000000000000000000000000000000000"
141	s, _ = new(Scalar).SetBytesWithClamping(decodeHex(zero))
142	p = new(Point).ScalarBaseMult(s)
143	want = "693e47972caf527c7883ad1b39822f026f47db2ab0e1919955b8993aa04411d1"
144	if got := hex.EncodeToString(p.Bytes()); got != want {
145		t.Errorf("zero: got %q, want %q", got, want)
146	}
147
148	one := "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
149	s, _ = new(Scalar).SetBytesWithClamping(decodeHex(one))
150	p = new(Point).ScalarBaseMult(s)
151	want = "12e9a68b73fd5aacdbcaf3e88c46fea6ebedb1aa84eed1842f07f8edab65e3a7"
152	if got := hex.EncodeToString(p.Bytes()); got != want {
153		t.Errorf("one: got %q, want %q", got, want)
154	}
155}
156
// bigIntFromLittleEndianBytes interprets b as an unsigned little-endian
// integer. big.Int.SetBytes expects big-endian input, so the bytes are
// reversed in a private copy; b itself is left untouched.
func bigIntFromLittleEndianBytes(b []byte) *big.Int {
	be := make([]byte, len(b))
	copy(be, b)
	for lo, hi := 0, len(be)-1; lo < hi; lo, hi = lo+1, hi-1 {
		be[lo], be[hi] = be[hi], be[lo]
	}
	return new(big.Int).SetBytes(be)
}
164
165func TestScalarMultiplyDistributesOverAdd(t *testing.T) {
166	multiplyDistributesOverAdd := func(x, y, z Scalar) bool {
167		// Compute t1 = (x+y)*z
168		var t1 Scalar
169		t1.Add(&x, &y)
170		t1.Multiply(&t1, &z)
171
172		// Compute t2 = x*z + y*z
173		var t2 Scalar
174		var t3 Scalar
175		t2.Multiply(&x, &z)
176		t3.Multiply(&y, &z)
177		t2.Add(&t2, &t3)
178
179		reprT1, reprT2 := t1.Bytes(), t2.Bytes()
180
181		return t1 == t2 && isReduced(reprT1) && isReduced(reprT2)
182	}
183
184	if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig(1024)); err != nil {
185		t.Error(err)
186	}
187}
188
189func TestScalarAddLikeSubNeg(t *testing.T) {
190	addLikeSubNeg := func(x, y Scalar) bool {
191		// Compute t1 = x - y
192		var t1 Scalar
193		t1.Subtract(&x, &y)
194
195		// Compute t2 = -y + x
196		var t2 Scalar
197		t2.Negate(&y)
198		t2.Add(&t2, &x)
199
200		return t1 == t2 && isReduced(t1.Bytes())
201	}
202
203	if err := quick.Check(addLikeSubNeg, quickCheckConfig(1024)); err != nil {
204		t.Error(err)
205	}
206}
207
208func TestScalarNonAdjacentForm(t *testing.T) {
209	s, _ := (&Scalar{}).SetCanonicalBytes([]byte{
210		0x1a, 0x0e, 0x97, 0x8a, 0x90, 0xf6, 0x62, 0x2d,
211		0x37, 0x47, 0x02, 0x3f, 0x8a, 0xd8, 0x26, 0x4d,
212		0xa7, 0x58, 0xaa, 0x1b, 0x88, 0xe0, 0x40, 0xd1,
213		0x58, 0x9e, 0x7b, 0x7f, 0x23, 0x76, 0xef, 0x09,
214	})
215
216	expectedNaf := [256]int8{
217		0, 13, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, -9, 0, 0, 0, 0, -11, 0, 0, 0, 0, 3, 0, 0, 0, 0, 1,
218		0, 0, 0, 0, 9, 0, 0, 0, 0, -5, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 11, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0,
219		-9, 0, 0, 0, 0, 0, -3, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 9, 0,
220		0, 0, 0, -15, 0, 0, 0, 0, -7, 0, 0, 0, 0, -9, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, -3, 0,
221		0, 0, 0, -11, 0, 0, 0, 0, -7, 0, 0, 0, 0, -13, 0, 0, 0, 0, 11, 0, 0, 0, 0, -9, 0, 0, 0, 0, 0, 1, 0, 0,
222		0, 0, 0, -15, 0, 0, 0, 0, 1, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 13, 0, 0, 0,
223		0, 0, 0, 11, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, -9, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 7,
224		0, 0, 0, 0, 0, -15, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 15, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
225	}
226
227	sNaf := s.nonAdjacentForm(5)
228
229	for i := 0; i < 256; i++ {
230		if expectedNaf[i] != sNaf[i] {
231			t.Errorf("Wrong digit at position %d, got %d, expected %d", i, sNaf[i], expectedNaf[i])
232		}
233	}
234}
235
236type notZeroScalar Scalar
237
238func (notZeroScalar) Generate(rand *mathrand.Rand, size int) reflect.Value {
239	var s Scalar
240	var isNonZero uint64
241	for isNonZero == 0 {
242		s = Scalar{}.Generate(rand, size).Interface().(Scalar)
243		fiatScalarNonzero(&isNonZero, (*[4]uint64)(&s.s))
244	}
245	return reflect.ValueOf(notZeroScalar(s))
246}
247
248func TestScalarEqual(t *testing.T) {
249	if scOne.Equal(scMinusOne) == 1 {
250		t.Errorf("scOne.Equal(&scMinusOne) is true")
251	}
252	if scMinusOne.Equal(scMinusOne) == 0 {
253		t.Errorf("scMinusOne.Equal(&scMinusOne) is false")
254	}
255}
256