1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mlkem768
6
7import (
8	"bytes"
9	"crypto/rand"
10	_ "embed"
11	"encoding/hex"
12	"errors"
13	"flag"
14	"math/big"
15	"strconv"
16	"testing"
17
18	"golang.org/x/crypto/sha3"
19)
20
21func TestFieldReduce(t *testing.T) {
22	for a := uint32(0); a < 2*q*q; a++ {
23		got := fieldReduce(a)
24		exp := fieldElement(a % q)
25		if got != exp {
26			t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
27		}
28	}
29}
30
31func TestFieldAdd(t *testing.T) {
32	for a := fieldElement(0); a < q; a++ {
33		for b := fieldElement(0); b < q; b++ {
34			got := fieldAdd(a, b)
35			exp := (a + b) % q
36			if got != exp {
37				t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
38			}
39		}
40	}
41}
42
43func TestFieldSub(t *testing.T) {
44	for a := fieldElement(0); a < q; a++ {
45		for b := fieldElement(0); b < q; b++ {
46			got := fieldSub(a, b)
47			exp := (a - b + q) % q
48			if got != exp {
49				t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
50			}
51		}
52	}
53}
54
55func TestFieldMul(t *testing.T) {
56	for a := fieldElement(0); a < q; a++ {
57		for b := fieldElement(0); b < q; b++ {
58			got := fieldMul(a, b)
59			exp := fieldElement((uint32(a) * uint32(b)) % q)
60			if got != exp {
61				t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
62			}
63		}
64	}
65}
66
67func TestDecompressCompress(t *testing.T) {
68	for _, bits := range []uint8{1, 4, 10} {
69		for a := uint16(0); a < 1<<bits; a++ {
70			f := decompress(a, bits)
71			if f >= q {
72				t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
73			}
74			got := compress(f, bits)
75			if got != a {
76				t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
77			}
78		}
79
80		for a := fieldElement(0); a < q; a++ {
81			c := compress(a, bits)
82			if c >= 1<<bits {
83				t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
84			}
85			got := decompress(c, bits)
86			diff := min(a-got, got-a, a-got+q, got-a+q)
87			ceil := q / (1 << bits)
88			if diff > fieldElement(ceil) {
89				t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
90					a, bits, bits, got, diff, ceil)
91			}
92		}
93	}
94}
95
96func CompressRat(x fieldElement, d uint8) uint16 {
97	if x >= q {
98		panic("x out of range")
99	}
100	if d <= 0 || d >= 12 {
101		panic("d out of range")
102	}
103
104	precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q
105
106	// FloatString rounds halves away from 0, and our result should always be positive,
107	// so it should work as we expect. (There's no direct way to round a Rat.)
108	rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
109	if err != nil {
110		panic(err)
111	}
112
113	// If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
114	return uint16(rounded % (1 << d))
115}
116
117func TestCompress(t *testing.T) {
118	for d := 1; d < 12; d++ {
119		for n := 0; n < q; n++ {
120			expected := CompressRat(fieldElement(n), uint8(d))
121			result := compress(fieldElement(n), uint8(d))
122			if result != expected {
123				t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
124			}
125		}
126	}
127}
128
129func DecompressRat(y uint16, d uint8) fieldElement {
130	if y >= 1<<d {
131		panic("y out of range")
132	}
133	if d <= 0 || d >= 12 {
134		panic("d out of range")
135	}
136
137	precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y  ==  (q * y) / 2ᵈ
138
139	// FloatString rounds halves away from 0, and our result should always be positive,
140	// so it should work as we expect. (There's no direct way to round a Rat.)
141	rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
142	if err != nil {
143		panic(err)
144	}
145
146	// If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
147	return fieldElement(rounded % q)
148}
149
150func TestDecompress(t *testing.T) {
151	for d := 1; d < 12; d++ {
152		for n := 0; n < (1 << d); n++ {
153			expected := DecompressRat(uint16(n), uint8(d))
154			result := decompress(uint16(n), uint8(d))
155			if result != expected {
156				t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
157			}
158		}
159	}
160}
161
// BitRev7 reverses the order of the low seven bits of n: bit 0 swaps with
// bit 6, bit 1 with bit 5, bit 2 with bit 4, and bit 3 stays put.
// It panics if n does not fit in 7 bits.
func BitRev7(n uint8) uint8 {
	if n >= 1<<7 {
		panic("not 7 bits")
	}
	var rev uint8
	for i := uint(0); i < 7; i++ {
		bit := (n >> i) & 1
		rev |= bit << (6 - i)
	}
	return rev
}
176
177func TestZetas(t *testing.T) {
178	ζ := big.NewInt(17)
179	q := big.NewInt(q)
180	for k, zeta := range zetas {
181		// ζ^BitRev7(k) mod q
182		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
183		if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
184			t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
185		}
186	}
187}
188
189func TestGammas(t *testing.T) {
190	ζ := big.NewInt(17)
191	q := big.NewInt(q)
192	for k, gamma := range gammas {
193		// ζ^2BitRev7(i)+1
194		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
195		if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
196			t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
197		}
198	}
199}
200
201func TestRoundTrip(t *testing.T) {
202	dk, err := GenerateKey()
203	if err != nil {
204		t.Fatal(err)
205	}
206	c, Ke, err := Encapsulate(dk.EncapsulationKey())
207	if err != nil {
208		t.Fatal(err)
209	}
210	Kd, err := Decapsulate(dk, c)
211	if err != nil {
212		t.Fatal(err)
213	}
214	if !bytes.Equal(Ke, Kd) {
215		t.Fail()
216	}
217
218	dk1, err := GenerateKey()
219	if err != nil {
220		t.Fatal(err)
221	}
222	if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) {
223		t.Fail()
224	}
225	if bytes.Equal(dk.Bytes(), dk1.Bytes()) {
226		t.Fail()
227	}
228	if bytes.Equal(dk.Bytes()[EncapsulationKeySize-32:], dk1.Bytes()[EncapsulationKeySize-32:]) {
229		t.Fail()
230	}
231
232	c1, Ke1, err := Encapsulate(dk.EncapsulationKey())
233	if err != nil {
234		t.Fatal(err)
235	}
236	if bytes.Equal(c, c1) {
237		t.Fail()
238	}
239	if bytes.Equal(Ke, Ke1) {
240		t.Fail()
241	}
242}
243
244func TestBadLengths(t *testing.T) {
245	dk, err := GenerateKey()
246	if err != nil {
247		t.Fatal(err)
248	}
249	ek := dk.EncapsulationKey()
250
251	for i := 0; i < len(ek)-1; i++ {
252		if _, _, err := Encapsulate(ek[:i]); err == nil {
253			t.Errorf("expected error for ek length %d", i)
254		}
255	}
256	ekLong := ek
257	for i := 0; i < 100; i++ {
258		ekLong = append(ekLong, 0)
259		if _, _, err := Encapsulate(ekLong); err == nil {
260			t.Errorf("expected error for ek length %d", len(ekLong))
261		}
262	}
263
264	c, _, err := Encapsulate(ek)
265	if err != nil {
266		t.Fatal(err)
267	}
268
269	for i := 0; i < len(dk.Bytes())-1; i++ {
270		if _, err := NewKeyFromExtendedEncoding(dk.Bytes()[:i]); err == nil {
271			t.Errorf("expected error for dk length %d", i)
272		}
273	}
274	dkLong := dk.Bytes()
275	for i := 0; i < 100; i++ {
276		dkLong = append(dkLong, 0)
277		if _, err := NewKeyFromExtendedEncoding(dkLong); err == nil {
278			t.Errorf("expected error for dk length %d", len(dkLong))
279		}
280	}
281
282	for i := 0; i < len(c)-1; i++ {
283		if _, err := Decapsulate(dk, c[:i]); err == nil {
284			t.Errorf("expected error for c length %d", i)
285		}
286	}
287	cLong := c
288	for i := 0; i < 100; i++ {
289		cLong = append(cLong, 0)
290		if _, err := Decapsulate(dk, cLong); err == nil {
291			t.Errorf("expected error for c length %d", len(cLong))
292		}
293	}
294}
295
296func EncapsulateDerand(ek, m []byte) (c, K []byte, err error) {
297	if len(m) != messageSize {
298		return nil, nil, errors.New("bad message length")
299	}
300	return kemEncaps(nil, ek, (*[messageSize]byte)(m))
301}
302
303func DecapsulateFromBytes(dkBytes []byte, c []byte) ([]byte, error) {
304	dk, err := NewKeyFromExtendedEncoding(dkBytes)
305	if err != nil {
306		return nil, err
307	}
308	return Decapsulate(dk, c)
309}
310
311func GenerateKeyDerand(t testing.TB, d, z []byte) ([]byte, *DecapsulationKey) {
312	if len(d) != 32 || len(z) != 32 {
313		t.Fatal("bad length")
314	}
315	dk := kemKeyGen(nil, (*[32]byte)(d), (*[32]byte)(z))
316	return dk.EncapsulationKey(), dk
317}
318
var (
	// millionFlag is set by passing -million; it makes
	// TestPQCrystalsAccumulated process one million vectors.
	millionFlag = flag.Bool("million", false, "run the million vector test")
)
320
321// TestPQCrystalsAccumulated accumulates the 10k vectors generated by the
322// reference implementation and checks the hash of the result, to avoid checking
323// in 150MB of test vectors.
324func TestPQCrystalsAccumulated(t *testing.T) {
325	n := 10000
326	expected := "f7db260e1137a742e05fe0db9525012812b004d29040a5b606aad3d134b548d3"
327	if testing.Short() {
328		n = 100
329		expected = "8d0c478ead6037897a0da6be21e5399545babf5fc6dd10c061c99b7dee2bf0dc"
330	}
331	if *millionFlag {
332		n = 1000000
333		expected = "70090cc5842aad0ec43d5042c783fae9bc320c047b5dafcb6e134821db02384d"
334	}
335
336	s := sha3.NewShake128()
337	o := sha3.NewShake128()
338	d := make([]byte, 32)
339	z := make([]byte, 32)
340	msg := make([]byte, 32)
341	ct1 := make([]byte, CiphertextSize)
342
343	for i := 0; i < n; i++ {
344		s.Read(d)
345		s.Read(z)
346		ek, dk := GenerateKeyDerand(t, d, z)
347		o.Write(ek)
348		o.Write(dk.Bytes())
349
350		s.Read(msg)
351		ct, k, err := EncapsulateDerand(ek, msg)
352		if err != nil {
353			t.Fatal(err)
354		}
355		o.Write(ct)
356		o.Write(k)
357
358		kk, err := Decapsulate(dk, ct)
359		if err != nil {
360			t.Fatal(err)
361		}
362		if !bytes.Equal(kk, k) {
363			t.Errorf("k: got %x, expected %x", kk, k)
364		}
365
366		s.Read(ct1)
367		k1, err := Decapsulate(dk, ct1)
368		if err != nil {
369			t.Fatal(err)
370		}
371		o.Write(k1)
372	}
373
374	got := hex.EncodeToString(o.Sum(nil))
375	if got != expected {
376		t.Errorf("got %s, expected %s", got, expected)
377	}
378}
379
// sink is a package-level variable the benchmarks XOR their results into,
// keeping those results observable so the measured work is not optimized away.
var sink uint8
381
382func BenchmarkKeyGen(b *testing.B) {
383	var dk DecapsulationKey
384	var d, z [32]byte
385	rand.Read(d[:])
386	rand.Read(z[:])
387	b.ResetTimer()
388	for i := 0; i < b.N; i++ {
389		dk := kemKeyGen(&dk, &d, &z)
390		sink ^= dk.EncapsulationKey()[0]
391	}
392}
393
394func BenchmarkEncaps(b *testing.B) {
395	d := make([]byte, 32)
396	rand.Read(d)
397	z := make([]byte, 32)
398	rand.Read(z)
399	var m [messageSize]byte
400	rand.Read(m[:])
401	ek, _ := GenerateKeyDerand(b, d, z)
402	var c [CiphertextSize]byte
403	b.ResetTimer()
404	for i := 0; i < b.N; i++ {
405		c, K, err := kemEncaps(&c, ek, &m)
406		if err != nil {
407			b.Fatal(err)
408		}
409		sink ^= c[0] ^ K[0]
410	}
411}
412
413func BenchmarkDecaps(b *testing.B) {
414	d := make([]byte, 32)
415	rand.Read(d)
416	z := make([]byte, 32)
417	rand.Read(z)
418	m := make([]byte, 32)
419	rand.Read(m)
420	ek, dk := GenerateKeyDerand(b, d, z)
421	c, _, err := EncapsulateDerand(ek, m)
422	if err != nil {
423		b.Fatal(err)
424	}
425	b.ResetTimer()
426	for i := 0; i < b.N; i++ {
427		K := kemDecaps(dk, (*[CiphertextSize]byte)(c))
428		sink ^= K[0]
429	}
430}
431
432func BenchmarkRoundTrip(b *testing.B) {
433	dk, err := GenerateKey()
434	if err != nil {
435		b.Fatal(err)
436	}
437	ek := dk.EncapsulationKey()
438	c, _, err := Encapsulate(ek)
439	if err != nil {
440		b.Fatal(err)
441	}
442	b.Run("Alice", func(b *testing.B) {
443		for i := 0; i < b.N; i++ {
444			dkS, err := GenerateKey()
445			if err != nil {
446				b.Fatal(err)
447			}
448			ekS := dkS.EncapsulationKey()
449			sink ^= ekS[0]
450
451			Ks, err := Decapsulate(dk, c)
452			if err != nil {
453				b.Fatal(err)
454			}
455			sink ^= Ks[0]
456		}
457	})
458	b.Run("Bob", func(b *testing.B) {
459		for i := 0; i < b.N; i++ {
460			cS, Ks, err := Encapsulate(ek)
461			if err != nil {
462				b.Fatal(err)
463			}
464			sink ^= cS[0] ^ Ks[0]
465		}
466	})
467}
468