1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bigmod
6
7import (
8	"fmt"
9	"math/big"
10	"math/bits"
11	"math/rand"
12	"reflect"
13	"strings"
14	"testing"
15	"testing/quick"
16)
17
18func (n *Nat) String() string {
19	var limbs []string
20	for i := range n.limbs {
21		limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
22	}
23	return "{" + strings.Join(limbs, " ") + "}"
24}
25
26// Generate generates an even nat. It's used by testing/quick to produce random
27// *nat values for quick.Check invocations.
28func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
29	limbs := make([]uint, size)
30	for i := 0; i < size; i++ {
31		limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
32	}
33	return reflect.ValueOf(&Nat{limbs})
34}
35
36func testModAddCommutative(a *Nat, b *Nat) bool {
37	m := maxModulus(uint(len(a.limbs)))
38	aPlusB := new(Nat).set(a)
39	aPlusB.Add(b, m)
40	bPlusA := new(Nat).set(b)
41	bPlusA.Add(a, m)
42	return aPlusB.Equal(bPlusA) == 1
43}
44
45func TestModAddCommutative(t *testing.T) {
46	err := quick.Check(testModAddCommutative, &quick.Config{})
47	if err != nil {
48		t.Error(err)
49	}
50}
51
52func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
53	m := maxModulus(uint(len(a.limbs)))
54	original := new(Nat).set(a)
55	a.Sub(b, m)
56	a.Add(b, m)
57	return a.Equal(original) == 1
58}
59
60func TestModSubThenAddIdentity(t *testing.T) {
61	err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
62	if err != nil {
63		t.Error(err)
64	}
65}
66
67func TestMontgomeryRoundtrip(t *testing.T) {
68	err := quick.Check(func(a *Nat) bool {
69		one := &Nat{make([]uint, len(a.limbs))}
70		one.limbs[0] = 1
71		aPlusOne := new(big.Int).SetBytes(natBytes(a))
72		aPlusOne.Add(aPlusOne, big.NewInt(1))
73		m, _ := NewModulusFromBig(aPlusOne)
74		monty := new(Nat).set(a)
75		monty.montgomeryRepresentation(m)
76		aAgain := new(Nat).set(monty)
77		aAgain.montgomeryMul(monty, one, m)
78		if a.Equal(aAgain) != 1 {
79			t.Errorf("%v != %v", a, aAgain)
80			return false
81		}
82		return true
83	}, &quick.Config{})
84	if err != nil {
85		t.Error(err)
86	}
87}
88
89func TestShiftIn(t *testing.T) {
90	if bits.UintSize != 64 {
91		t.Skip("examples are only valid in 64 bit")
92	}
93	examples := []struct {
94		m, x, expected []byte
95		y              uint64
96	}{{
97		m:        []byte{13},
98		x:        []byte{0},
99		y:        0xFFFF_FFFF_FFFF_FFFF,
100		expected: []byte{2},
101	}, {
102		m:        []byte{13},
103		x:        []byte{7},
104		y:        0xFFFF_FFFF_FFFF_FFFF,
105		expected: []byte{10},
106	}, {
107		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
108		x:        make([]byte, 9),
109		y:        0xFFFF_FFFF_FFFF_FFFF,
110		expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
111	}, {
112		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
113		x:        []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
114		y:        0,
115		expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
116	}}
117
118	for i, tt := range examples {
119		m := modulusFromBytes(tt.m)
120		got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
121		if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
122			t.Errorf("%d: got %v, expected %v", i, got, exp)
123		}
124	}
125}
126
127func TestModulusAndNatSizes(t *testing.T) {
128	// These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
129	// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
130	// limbs, if they are not, they fit in three. This can be a problem because
131	// modulus strips leading zeroes and nat does not.
132	m := modulusFromBytes([]byte{
133		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
134		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
135	xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
136		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
137	natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
138	NewNat().SetBytes(xb, m)
139}
140
141func TestSetBytes(t *testing.T) {
142	tests := []struct {
143		m, b []byte
144		fail bool
145	}{{
146		m: []byte{0xff, 0xff},
147		b: []byte{0x00, 0x01},
148	}, {
149		m:    []byte{0xff, 0xff},
150		b:    []byte{0xff, 0xff},
151		fail: true,
152	}, {
153		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
154		b: []byte{0x00, 0x01},
155	}, {
156		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
157		b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
158	}, {
159		m:    []byte{0xff, 0xff},
160		b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
161		fail: true,
162	}, {
163		m:    []byte{0xff, 0xff},
164		b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
165		fail: true,
166	}, {
167		m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
168		b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
169	}, {
170		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
171		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
172		fail: true,
173	}, {
174		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
175		b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
176		fail: true,
177	}, {
178		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
179		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
180		fail: true,
181	}, {
182		m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
183		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
184		fail: true,
185	}}
186
187	for i, tt := range tests {
188		m := modulusFromBytes(tt.m)
189		got, err := NewNat().SetBytes(tt.b, m)
190		if err != nil {
191			if !tt.fail {
192				t.Errorf("%d: unexpected error: %v", i, err)
193			}
194			continue
195		}
196		if tt.fail {
197			t.Errorf("%d: unexpected success", i)
198			continue
199		}
200		if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
201			t.Errorf("%d: got %v, expected %v", i, got, expected)
202		}
203	}
204
205	f := func(xBytes []byte) bool {
206		m := maxModulus(uint(len(xBytes)*8/_W + 1))
207		got, err := NewNat().SetBytes(xBytes, m)
208		if err != nil {
209			return false
210		}
211		return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
212	}
213
214	err := quick.Check(f, &quick.Config{})
215	if err != nil {
216		t.Error(err)
217	}
218}
219
220func TestExpand(t *testing.T) {
221	sliced := []uint{1, 2, 3, 4}
222	examples := []struct {
223		in  []uint
224		n   int
225		out []uint
226	}{{
227		[]uint{1, 2},
228		4,
229		[]uint{1, 2, 0, 0},
230	}, {
231		sliced[:2],
232		4,
233		[]uint{1, 2, 0, 0},
234	}, {
235		[]uint{1, 2},
236		2,
237		[]uint{1, 2},
238	}}
239
240	for i, tt := range examples {
241		got := (&Nat{tt.in}).expand(tt.n)
242		if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
243			t.Errorf("%d: got %v, expected %v", i, got, tt.out)
244		}
245	}
246}
247
248func TestMod(t *testing.T) {
249	m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
250	x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
251	out := new(Nat)
252	out.Mod(x, m)
253	expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
254	if out.Equal(expected) != 1 {
255		t.Errorf("%+v != %+v", out, expected)
256	}
257}
258
259func TestModSub(t *testing.T) {
260	m := modulusFromBytes([]byte{13})
261	x := &Nat{[]uint{6}}
262	y := &Nat{[]uint{7}}
263	x.Sub(y, m)
264	expected := &Nat{[]uint{12}}
265	if x.Equal(expected) != 1 {
266		t.Errorf("%+v != %+v", x, expected)
267	}
268	x.Sub(y, m)
269	expected = &Nat{[]uint{5}}
270	if x.Equal(expected) != 1 {
271		t.Errorf("%+v != %+v", x, expected)
272	}
273}
274
275func TestModAdd(t *testing.T) {
276	m := modulusFromBytes([]byte{13})
277	x := &Nat{[]uint{6}}
278	y := &Nat{[]uint{7}}
279	x.Add(y, m)
280	expected := &Nat{[]uint{0}}
281	if x.Equal(expected) != 1 {
282		t.Errorf("%+v != %+v", x, expected)
283	}
284	x.Add(y, m)
285	expected = &Nat{[]uint{7}}
286	if x.Equal(expected) != 1 {
287		t.Errorf("%+v != %+v", x, expected)
288	}
289}
290
291func TestExp(t *testing.T) {
292	m := modulusFromBytes([]byte{13})
293	x := &Nat{[]uint{3}}
294	out := &Nat{[]uint{0}}
295	out.Exp(x, []byte{12}, m)
296	expected := &Nat{[]uint{1}}
297	if out.Equal(expected) != 1 {
298		t.Errorf("%+v != %+v", out, expected)
299	}
300}
301
302func TestExpShort(t *testing.T) {
303	m := modulusFromBytes([]byte{13})
304	x := &Nat{[]uint{3}}
305	out := &Nat{[]uint{0}}
306	out.ExpShortVarTime(x, 12, m)
307	expected := &Nat{[]uint{1}}
308	if out.Equal(expected) != 1 {
309		t.Errorf("%+v != %+v", out, expected)
310	}
311}
312
313// TestMulReductions tests that Mul reduces results equal or slightly greater
314// than the modulus. Some Montgomery algorithms don't and need extra care to
315// return correct results. See https://go.dev/issue/13907.
316func TestMulReductions(t *testing.T) {
317	// Two short but multi-limb primes.
318	a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
319	b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
320	n := new(big.Int).Mul(a, b)
321
322	N, _ := NewModulusFromBig(n)
323	A := NewNat().setBig(a).ExpandFor(N)
324	B := NewNat().setBig(b).ExpandFor(N)
325
326	if A.Mul(B, N).IsZero() != 1 {
327		t.Error("a * b mod (a * b) != 0")
328	}
329
330	i := new(big.Int).ModInverse(a, b)
331	N, _ = NewModulusFromBig(b)
332	A = NewNat().setBig(a).ExpandFor(N)
333	I := NewNat().setBig(i).ExpandFor(N)
334	one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
335
336	if A.Mul(I, N).Equal(one) != 1 {
337		t.Error("a * inv(a) mod b != 1")
338	}
339}
340
341func natBytes(n *Nat) []byte {
342	return n.Bytes(maxModulus(uint(len(n.limbs))))
343}
344
345func natFromBytes(b []byte) *Nat {
346	// Must not use Nat.SetBytes as it's used in TestSetBytes.
347	bb := new(big.Int).SetBytes(b)
348	return NewNat().setBig(bb)
349}
350
351func modulusFromBytes(b []byte) *Modulus {
352	bb := new(big.Int).SetBytes(b)
353	m, _ := NewModulusFromBig(bb)
354	return m
355}
356
357// maxModulus returns the biggest modulus that can fit in n limbs.
358func maxModulus(n uint) *Modulus {
359	b := big.NewInt(1)
360	b.Lsh(b, n*_W)
361	b.Sub(b, big.NewInt(1))
362	m, _ := NewModulusFromBig(b)
363	return m
364}
365
366func makeBenchmarkModulus() *Modulus {
367	return maxModulus(32)
368}
369
370func makeBenchmarkValue() *Nat {
371	x := make([]uint, 32)
372	for i := 0; i < 32; i++ {
373		x[i]--
374	}
375	return &Nat{limbs: x}
376}
377
// makeBenchmarkExponent returns a 256-byte big-endian exponent whose first 32
// bytes are 0xFF and whose remaining bytes are zero.
func makeBenchmarkExponent() []byte {
	exponent := make([]byte, 256)
	head := exponent[:32]
	for i := range head {
		head[i] = 0xFF
	}
	return exponent
}
385
386func BenchmarkModAdd(b *testing.B) {
387	x := makeBenchmarkValue()
388	y := makeBenchmarkValue()
389	m := makeBenchmarkModulus()
390
391	b.ResetTimer()
392	for i := 0; i < b.N; i++ {
393		x.Add(y, m)
394	}
395}
396
397func BenchmarkModSub(b *testing.B) {
398	x := makeBenchmarkValue()
399	y := makeBenchmarkValue()
400	m := makeBenchmarkModulus()
401
402	b.ResetTimer()
403	for i := 0; i < b.N; i++ {
404		x.Sub(y, m)
405	}
406}
407
408func BenchmarkMontgomeryRepr(b *testing.B) {
409	x := makeBenchmarkValue()
410	m := makeBenchmarkModulus()
411
412	b.ResetTimer()
413	for i := 0; i < b.N; i++ {
414		x.montgomeryRepresentation(m)
415	}
416}
417
418func BenchmarkMontgomeryMul(b *testing.B) {
419	x := makeBenchmarkValue()
420	y := makeBenchmarkValue()
421	out := makeBenchmarkValue()
422	m := makeBenchmarkModulus()
423
424	b.ResetTimer()
425	for i := 0; i < b.N; i++ {
426		out.montgomeryMul(x, y, m)
427	}
428}
429
430func BenchmarkModMul(b *testing.B) {
431	x := makeBenchmarkValue()
432	y := makeBenchmarkValue()
433	m := makeBenchmarkModulus()
434
435	b.ResetTimer()
436	for i := 0; i < b.N; i++ {
437		x.Mul(y, m)
438	}
439}
440
441func BenchmarkExpBig(b *testing.B) {
442	out := new(big.Int)
443	exponentBytes := makeBenchmarkExponent()
444	x := new(big.Int).SetBytes(exponentBytes)
445	e := new(big.Int).SetBytes(exponentBytes)
446	n := new(big.Int).SetBytes(exponentBytes)
447	one := new(big.Int).SetUint64(1)
448	n.Add(n, one)
449
450	b.ResetTimer()
451	for i := 0; i < b.N; i++ {
452		out.Exp(x, e, n)
453	}
454}
455
456func BenchmarkExp(b *testing.B) {
457	x := makeBenchmarkValue()
458	e := makeBenchmarkExponent()
459	out := makeBenchmarkValue()
460	m := makeBenchmarkModulus()
461
462	b.ResetTimer()
463	for i := 0; i < b.N; i++ {
464		out.Exp(x, e, m)
465	}
466}
467
468func TestNewModFromBigZero(t *testing.T) {
469	expected := "modulus must be >= 0"
470	_, err := NewModulusFromBig(big.NewInt(0))
471	if err == nil || err.Error() != expected {
472		t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
473	}
474
475	expected = "modulus must be odd"
476	_, err = NewModulusFromBig(big.NewInt(2))
477	if err == nil || err.Error() != expected {
478		t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
479	}
480}
481