xref: /aosp_15_r20/external/boringssl/src/ssl/test/runner/kyber/kyber.go (revision 8fb009dc861624b67b6cdb62ea21f0f22d0c584b)
1/* Copyright (c) 2023, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15package kyber
16
17// This code is ported from kyber.c.
18
19import (
20	"crypto/subtle"
21	"golang.org/x/crypto/sha3"
22	"io"
23)
24
// Sizes, in bytes, of the serialized Kyber768 objects handled by this package.
const (
	// CiphertextSize is the length of a ciphertext: u compressed to du bits
	// per coefficient (960 bytes) followed by v compressed to dv bits per
	// coefficient (128 bytes).
	CiphertextSize = 1088

	// PublicKeySize is the length of a marshalled public key: t encoded with
	// log2Prime bits per coefficient (1152 bytes) followed by the 32-byte rho.
	PublicKeySize = 1184

	// PrivateKeySize is the length of a marshalled private key: s (1152
	// bytes), the marshalled public key, its 32-byte hash and the 32-byte
	// implicit-rejection secret.
	PrivateKeySize = 2400
)
30
// Kyber768 parameters and the constants used for arithmetic modulo prime.
const (
	// degree is the number of coefficients in each polynomial (scalar).
	degree = 256

	// rank is the number of polynomials in a vector; a matrix is rank x rank.
	rank = 3

	// prime is the modulus; coefficients are kept in [0, prime).
	prime = 3329

	// log2Prime is the bit width used to encode an uncompressed coefficient.
	log2Prime = 12

	// halfPrime is floor(prime/2), the rounding threshold used by compress.
	halfPrime = (prime - 1) / 2

	// inverseDegree is 128^-1 mod prime. ntt runs seven layers, so
	// inverseNTT scales by 1/128 rather than 1/256.
	inverseDegree = 3303

	// du and dv are the bit widths that the ciphertext's vector and scalar
	// parts are compressed to.
	du = 10
	dv = 4

	// encodedVectorSize is the byte length of a vector encoded with
	// log2Prime bits per coefficient.
	encodedVectorSize = rank * (log2Prime * degree / 8)

	// compressedVectorSize is the byte length of a vector compressed to du
	// bits per coefficient.
	compressedVectorSize = rank * (du * degree / 8)

	// barrettMultiplier is floor(2^barrettShift / prime); together they
	// implement Barrett reduction in reduce and compress.
	barrettMultiplier = 5039
	barrettShift      = 24
)
45
46func reduceOnce(x uint16) uint16 {
47	if x >= 2*prime {
48		panic("reduce_once: value out of range")
49	}
50	subtracted := x - prime
51	mask := 0 - (subtracted >> 15)
52	return (mask & x) | (^mask & subtracted)
53}
54
55func reduce(x uint32) uint16 {
56	if x >= prime+2*prime*prime {
57		panic("reduce: value out of range")
58	}
59	product := uint64(x) * barrettMultiplier
60	quotient := uint32(product >> barrettShift)
61	remainder := uint32(x) - quotient*prime
62	return reduceOnce(uint16(remainder))
63}
64
// lt returns 0xffffffff if a < b (compared as unsigned integers) and 0
// otherwise, without branching on a or b.
//
// a ^ ((a ^ b) | ((a - b) ^ a)) has its most significant bit set exactly when
// a < b. Converting to int32 and shifting right arithmetically copies that bit
// into every position, producing an all-ones or all-zeros mask. (The previous
// form, 0 - int32(x)>>31, parsed as 0 - (int32(x)>>31) and returned 1/0
// instead of the documented mask.)
func lt(a, b uint32) uint32 {
	return uint32(int32(a^((a^b)|((a-b)^a))) >> 31)
}
69
70// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
71// numbers close to each other together. The formula used is
72// round(2^|bits|/prime*x) mod 2^|bits|.
73// Uses Barrett reduction to achieve constant time. Since we need both the
74// remainder (for rounding) and the quotient (as the result), we cannot use
75// |reduce| here, but need to do the Barrett reduction directly.
76func compress(x uint16, bits int) uint16 {
77	product := uint32(x) << bits
78	quotient := uint32((uint64(product) * barrettMultiplier) >> barrettShift)
79	remainder := product - quotient*prime
80
81	// Adjust the quotient to round correctly:
82	//   0 <= remainder <= halfPrime round to 0
83	//   halfPrime < remainder <= prime + halfPrime round to 1
84	//   prime + halfPrime < remainder < 2 * prime round to 2
85	quotient += 1 & lt(halfPrime, remainder)
86	quotient += 1 & lt(prime+halfPrime, remainder)
87	return uint16(quotient) & ((1 << bits) - 1)
88}
89
90func decompress(x uint16, bits int) uint16 {
91	product := uint32(x) * prime
92	power := uint32(1) << bits
93	// This is |product| % power, since |power| is a power of 2.
94	remainder := product & (power - 1)
95	// This is |product| / power, since |power| is a power of 2.
96	lower := product >> bits
97	// The rounding logic works since the first half of numbers mod |power| have a
98	// 0 as first bit, and the second half has a 1 as first bit, since |power| is
99	// a power of 2. As a 12 bit number, |remainder| is always positive, so we
100	// will shift in 0s for a right shift.
101	return uint16(lower + (remainder >> (bits - 1)))
102}
103
104type scalar [degree]uint16
105
106func (s *scalar) zero() {
107	for i := range s {
108		s[i] = 0
109	}
110}
111
112// This bit of Python will be referenced in some of the following comments:
113//
114// p = 3329
115//
116// def bitreverse(i):
117//     ret = 0
118//     for n in range(7):
119//         bit = i & 1
120//         ret <<= 1
121//         ret |= bit
122//         i >>= 1
123//     return ret
124
// nttRoots holds the twiddle factors consumed by ntt: entry i is 17 raised to
// the 7-bit reversal of i, mod prime. It was generated with the Python helper
// above:
//
//	nttRoots = [pow(17, bitreverse(i), p) for i in range(128)]
var nttRoots = [128]uint16{
	1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
	2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
	1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
	1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
	2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
	2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
	1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
	1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
	1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
	2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
}
139
// ntt converts s in place into the NTT representation consumed by mult. It is
// a port of scalar_ntt in kyber.c.
//
// Each outer iteration is one layer of Cooley-Tukey butterflies. Across layers
// offset halves (128, 64, ..., 2) and step doubles (1, 2, ..., 64). The i-th
// block of 2*offset coefficients in a layer uses twiddle factor
// nttRoots[i+step]. The loop stops before step reaches degree/2, so only
// seven layers run. The output is therefore 128 degree-one polynomials, which
// mult reads as coefficient pairs (s[2i], s[2i+1]).
func (s *scalar) ntt() {
	offset := degree
	for step := 1; step < degree/2; step <<= 1 {
		offset >>= 1
		k := 0
		for i := 0; i < step; i++ {
			stepRoot := uint32(nttRoots[i+step])
			for j := k; j < k+offset; j++ {
				// Butterfly (even, odd) -> (even + root*odd, even - root*odd).
				// Adding prime before reducing keeps the difference in
				// [0, 2*prime) as reduceOnce requires.
				odd := reduce(stepRoot * uint32(s[j+offset]))
				even := s[j]
				s[j] = reduceOnce(odd + even)
				s[j+offset] = reduceOnce(even - odd + prime)
			}
			k += 2 * offset
		}
	}
}
157
// inverseNTTRoots holds the twiddle factors consumed by inverseNTT: entry i is
// the inverse mod prime of 17 raised to the 7-bit reversal of i. Generated
// with the Python helper above:
//
//	inverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
var inverseNTTRoots = [128]uint16{
	1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
	2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
	1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
	2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
	1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
	2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
	2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
	2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
	1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
	2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
}
172
// inverseNTT undoes ntt in place. It is a port of scalar_inverse_ntt in
// kyber.c.
//
// The seven layers run in the opposite order (offset 2, 4, ..., 128) using
// Gentleman-Sande butterflies with twiddle factors from inverseNTTRoots.
// Afterwards every coefficient is multiplied by inverseDegree (128^-1 mod
// prime) to cancel the factor of 2^7 accumulated by the seven layers.
func (s *scalar) inverseNTT() {
	step := degree / 2
	for offset := 2; offset < degree; offset <<= 1 {
		step >>= 1
		k := 0
		for i := 0; i < step; i++ {
			stepRoot := uint32(inverseNTTRoots[i+step])
			for j := k; j < k+offset; j++ {
				// Butterfly (even, odd) -> (even + odd, root*(even - odd)).
				// prime is added so the uint16 difference is non-negative.
				odd := s[j+offset]
				even := s[j]
				s[j] = reduceOnce(odd + even)
				s[j+offset] = reduce(stepRoot * uint32(even-odd+prime))
			}
			k += 2 * offset
		}
	}
	for i := range s {
		s[i] = reduce(uint32(s[i]) * inverseDegree)
	}
}
193
194func (s *scalar) add(b *scalar) {
195	for i := range s {
196		s[i] = reduceOnce(s[i] + b[i])
197	}
198}
199
200func (s *scalar) sub(b *scalar) {
201	for i := range s {
202		s[i] = reduceOnce(s[i] - b[i] + prime)
203	}
204}
205
// modRoots holds, for each of the 128 coefficient pairs produced by ntt, the
// constant c_i that mult reduces by, using X^2 = c_i, i.e. it multiplies pair
// i modulo X^2 - modRoots[i]. Generated with the Python helper above:
//
//	modRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
var modRoots = [128]uint16{
	17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
	756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
	2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
	939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
	268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
	2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
	2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
	2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
}
220
221func (s *scalar) mult(a, b *scalar) {
222	for i := 0; i < degree/2; i++ {
223		realReal := uint32(a[2*i]) * uint32(b[2*i])
224		imgImg := uint32(a[2*i+1]) * uint32(b[2*i+1])
225		realImg := uint32(a[2*i]) * uint32(b[2*i+1])
226		imgReal := uint32(a[2*i+1]) * uint32(b[2*i])
227		s[2*i] = reduce(realReal + uint32(reduce(imgImg))*uint32(modRoots[i]))
228		s[2*i+1] = reduce(imgReal + realImg)
229	}
230}
231
232func (s *scalar) innerProduct(left, right *vector) {
233	s.zero()
234	var product scalar
235	for i := range left {
236		product.mult(&left[i], &right[i])
237		s.add(&product)
238	}
239}
240
241func (s *scalar) fromKeccakVartime(keccak io.Reader) {
242	var buf [3]byte
243	for i := 0; i < len(s); {
244		keccak.Read(buf[:])
245		d1 := uint16(buf[0]) + 256*uint16(buf[1]%16)
246		d2 := uint16(buf[1])/16 + 16*uint16(buf[2])
247		if d1 < prime {
248			s[i] = d1
249			i++
250		}
251		if d2 < prime && i < len(s) {
252			s[i] = d2
253			i++
254		}
255	}
256}
257
258func (s *scalar) centeredBinomialEta2(input *[33]byte) {
259	var entropy [128]byte
260	sha3.ShakeSum256(entropy[:], input[:])
261
262	for i := 0; i < len(s); i += 2 {
263		b := uint16(entropy[i/2])
264
265		value := uint16(prime)
266		value += (b & 1) + ((b >> 1) & 1)
267		value -= ((b >> 2) & 1) + ((b >> 3) & 1)
268		s[i] = reduceOnce(value)
269
270		b >>= 4
271		value = prime
272		value += (b & 1) + ((b >> 1) & 1)
273		value -= ((b >> 2) & 1) + ((b >> 3) & 1)
274		s[i+1] = reduceOnce(value)
275	}
276}
277
// masks[n-1] has the low n bits set, for n = 1..8. encode and decode use it to
// extract chunks of at most one byte.
var masks = [8]uint16{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff}
279
// encode packs the low bits bits of every coefficient of s into out. Bits are
// written least significant first, and coefficients are laid out back to back
// across byte boundaries. A trailing partially filled byte is also written.
// It returns the part of out after the last byte written. out must have room
// for ceil(degree*bits/8) bytes; otherwise the out[0] store panics.
func (s *scalar) encode(out []byte, bits int) []byte {
	var outByte byte
	outByteBits := 0

	for i := range s {
		element := s[i]
		elementBitsDone := 0

		for elementBitsDone < bits {
			// Take as many of the element's remaining bits as fit in the
			// current output byte (never more than 8, so masks suffices).
			chunkBits := bits - elementBitsDone
			outBitsRemaining := 8 - outByteBits
			if chunkBits >= outBitsRemaining {
				// The chunk completes the current byte: flush it.
				chunkBits = outBitsRemaining
				outByte |= byte(element&masks[chunkBits-1]) << outByteBits
				out[0] = outByte
				out = out[1:]
				outByteBits = 0
				outByte = 0
			} else {
				outByte |= byte(element&masks[chunkBits-1]) << outByteBits
				outByteBits += chunkBits
			}

			elementBitsDone += chunkBits
			element >>= chunkBits
		}
	}

	if outByteBits > 0 {
		out[0] = outByte
		out = out[1:]
	}

	return out
}
315
// decode is the inverse of encode. It reads degree coefficients of bits bits
// each (least significant bit first) from in, and returns the unread remainder
// of in and true. Unused high bits of the last byte consumed are dropped. If a
// decoded value is not below prime it returns (nil, false) immediately. In
// that case s may already be partially overwritten. in must hold enough bytes,
// or the in[0] read panics.
func (s *scalar) decode(in []byte, bits int) ([]byte, bool) {
	var inByte byte
	inByteBitsLeft := 0

	for i := range s {
		var element uint16
		elementBitsDone := 0

		for elementBitsDone < bits {
			if inByteBitsLeft == 0 {
				inByte = in[0]
				in = in[1:]
				inByteBitsLeft = 8
			}

			// Take as many bits as the element still needs, limited by
			// what is left in the current input byte.
			chunkBits := bits - elementBitsDone
			if chunkBits > inByteBitsLeft {
				chunkBits = inByteBitsLeft
			}

			element |= (uint16(inByte) & masks[chunkBits-1]) << elementBitsDone
			inByteBitsLeft -= chunkBits
			inByte >>= chunkBits

			elementBitsDone += chunkBits
		}

		if element >= prime {
			return nil, false
		}
		s[i] = element
	}

	return in, true
}
351
352func (s *scalar) compress(bits int) {
353	for i := range s {
354		s[i] = compress(s[i], bits)
355	}
356}
357
358func (s *scalar) decompress(bits int) {
359	for i := range s {
360		s[i] = decompress(s[i], bits)
361	}
362}
363
364type vector [rank]scalar
365
366func (v *vector) zero() {
367	for i := range v {
368		v[i].zero()
369	}
370}
371
372func (v *vector) ntt() {
373	for i := range v {
374		v[i].ntt()
375	}
376}
377
378func (v *vector) inverseNTT() {
379	for i := range v {
380		v[i].inverseNTT()
381	}
382}
383
384func (v *vector) add(b *vector) {
385	for i := range v {
386		v[i].add(&b[i])
387	}
388}
389
390func (out *vector) mult(m *matrix, v *vector) {
391	out.zero()
392	var product scalar
393	for i := 0; i < rank; i++ {
394		for j := 0; j < rank; j++ {
395			product.mult(&m[i][j], &v[j])
396			out[i].add(&product)
397		}
398	}
399}
400
401func (out *vector) multTranspose(m *matrix, v *vector) {
402	out.zero()
403	var product scalar
404	for i := 0; i < rank; i++ {
405		for j := 0; j < rank; j++ {
406			product.mult(&m[j][i], &v[j])
407			out[i].add(&product)
408		}
409	}
410}
411
412func (v *vector) generateSecretEta2(counter *byte, seed *[32]byte) {
413	var input [33]byte
414	copy(input[:], seed[:])
415	for i := range v {
416		input[32] = *counter
417		*counter++
418		v[i].centeredBinomialEta2(&input)
419	}
420}
421
422func (v *vector) encode(out []byte, bits int) []byte {
423	for i := range v {
424		out = v[i].encode(out, bits)
425	}
426	return out
427}
428
429func (v *vector) decode(out []byte, bits int) ([]byte, bool) {
430	var ok bool
431	for i := range v {
432		out, ok = v[i].decode(out, bits)
433		if !ok {
434			return nil, false
435		}
436	}
437
438	return out, true
439}
440
441func (v *vector) compress(bits int) {
442	for i := range v {
443		v[i].compress(bits)
444	}
445}
446
447func (v *vector) decompress(bits int) {
448	for i := range v {
449		v[i].decompress(bits)
450	}
451}
452
453type matrix [rank][rank]scalar
454
455func (m *matrix) expand(rho *[32]byte) {
456	shake := sha3.NewShake128()
457
458	var input [34]byte
459	copy(input[:], rho[:])
460
461	for i := 0; i < rank; i++ {
462		for j := 0; j < rank; j++ {
463			input[32] = byte(i)
464			input[33] = byte(j)
465
466			shake.Reset()
467			shake.Write(input[:])
468			m[i][j].fromKeccakVartime(shake)
469		}
470	}
471}
472
// PublicKey is a Kyber768 public key. Besides the encoded fields t and rho, it
// caches values derived from them that encryption needs.
type PublicKey struct {
	t             vector   // public vector, in the NTT domain
	rho           [32]byte // seed from which m is expanded
	publicKeyHash [32]byte // SHA3-256 of the marshalled public key
	m             matrix   // matrix expanded from rho, cached for encryption
}
479
480func UnmarshalPublicKey(data *[PublicKeySize]byte) (*PublicKey, bool) {
481	var ret PublicKey
482	ret.publicKeyHash = sha3.Sum256(data[:])
483	in, ok := ret.t.decode(data[:], log2Prime)
484	if !ok {
485		return nil, false
486	}
487	copy(ret.rho[:], in)
488	ret.m.expand(&ret.rho)
489	return &ret, true
490}
491
492func (pub *PublicKey) Marshal() *[PublicKeySize]byte {
493	var ret [PublicKeySize]byte
494	out := pub.t.encode(ret[:], log2Prime)
495	copy(out, pub.rho[:])
496	return &ret
497}
498
499func (pub *PublicKey) encryptCPA(message, entropy *[32]byte) *[CiphertextSize]byte {
500	var counter uint8
501	var secret, error vector
502	secret.generateSecretEta2(&counter, entropy)
503	error.generateSecretEta2(&counter, entropy)
504	secret.ntt()
505
506	var input [33]byte
507	copy(input[:], entropy[:])
508	input[32] = counter
509	var scalarError scalar
510	scalarError.centeredBinomialEta2(&input)
511
512	var u vector
513	u.mult(&pub.m, &secret)
514	u.inverseNTT()
515	u.add(&error)
516
517	var v scalar
518	v.innerProduct(&pub.t, &secret)
519	v.inverseNTT()
520	v.add(&scalarError)
521
522	out := make([]byte, CiphertextSize)
523	var expandedMessage scalar
524	expandedMessage.decode(message[:], 1)
525	expandedMessage.decompress(1)
526	v.add(&expandedMessage)
527	u.compress(du)
528	it := u.encode(out, du)
529	v.compress(dv)
530	v.encode(it, dv)
531	return (*[CiphertextSize]byte)(out)
532}
533
534func (pub *PublicKey) Encap(outSharedSecret []byte, entropy *[32]byte) *[CiphertextSize]byte {
535	var input [64]byte
536	copy(input[:], entropy[:])
537	copy(input[32:], pub.publicKeyHash[:])
538	prekeyAndRandomness := sha3.Sum512(input[:])
539	ciphertext := pub.encryptCPA(entropy, (*[32]byte)(prekeyAndRandomness[32:]))
540	ciphertextHash := sha3.Sum256(ciphertext[:])
541	copy(prekeyAndRandomness[32:], ciphertextHash[:])
542	sha3.ShakeSum256(outSharedSecret, prekeyAndRandomness[:])
543	return ciphertext
544}
545
// PrivateKey is a Kyber768 private key. It embeds the matching public key.
type PrivateKey struct {
	PublicKey
	s               vector   // secret vector, in the NTT domain
	foFailureSecret [32]byte // used by Decap instead of the prekey on mismatch
}
551
552func NewPrivateKey(entropy *[64]byte) (*PrivateKey, *[PublicKeySize]byte) {
553	hashed := sha3.Sum512(entropy[:32])
554	rho := (*[32]byte)(hashed[:32])
555	sigma := (*[32]byte)(hashed[32:])
556	ret := new(PrivateKey)
557	copy(ret.foFailureSecret[:], entropy[32:])
558	copy(ret.rho[:], rho[:])
559	ret.m.expand(rho)
560	counter := uint8(0)
561	ret.s.generateSecretEta2(&counter, sigma)
562	ret.s.ntt()
563	var error vector
564	error.generateSecretEta2(&counter, sigma)
565	error.ntt()
566	ret.t.multTranspose(&ret.m, &ret.s)
567	ret.t.add(&error)
568
569	marshalledPublicKey := ret.PublicKey.Marshal()
570	ret.publicKeyHash = sha3.Sum256(marshalledPublicKey[:])
571
572	return ret, marshalledPublicKey
573}
574
575func (priv *PrivateKey) decryptCPA(ciphertext *[CiphertextSize]byte) [32]byte {
576	var u vector
577	u.decode(ciphertext[:], du)
578	u.decompress(du)
579	u.ntt()
580
581	var v scalar
582	v.decode(ciphertext[compressedVectorSize:], dv)
583	v.decompress(dv)
584
585	var mask scalar
586	mask.innerProduct(&priv.s, &u)
587	mask.inverseNTT()
588	v.sub(&mask)
589	v.compress(1)
590	var out [32]byte
591	v.encode(out[:], 1)
592	return out
593}
594
595func (priv *PrivateKey) Decap(outSharedSecret []byte, ciphertext *[CiphertextSize]byte) {
596	decrypted := priv.decryptCPA(ciphertext)
597	h := sha3.New512()
598	h.Write(decrypted[:])
599	h.Write(priv.publicKeyHash[:])
600	prekeyAndRandomness := h.Sum(nil)
601	expectedCiphertext := priv.encryptCPA(&decrypted, (*[32]byte)(prekeyAndRandomness[32:]))
602	equal := subtle.ConstantTimeCompare(ciphertext[:], expectedCiphertext[:])
603	var secret [32]byte
604	for i := range secret {
605		secret[i] = byte(subtle.ConstantTimeSelect(equal, int(prekeyAndRandomness[i]), int(priv.foFailureSecret[i])))
606	}
607	ciphertextHash := sha3.Sum256(ciphertext[:])
608
609	shake := sha3.NewShake256()
610	shake.Write(secret[:])
611	shake.Write(ciphertextHash[:])
612	shake.Read(outSharedSecret)
613}
614
615func (priv *PrivateKey) Marshal() *[PrivateKeySize]byte {
616	var ret [PrivateKeySize]byte
617	out := priv.s.encode(ret[:], log2Prime)
618	publicKey := priv.PublicKey.Marshal()
619	n := copy(out, publicKey[:])
620	out = out[n:]
621	n = copy(out, priv.publicKeyHash[:])
622	out = out[n:]
623	copy(out, priv.foFailureSecret[:])
624	return &ret
625}
626