1/* Copyright (c) 2023, Google Inc. 2 * 3 * Permission to use, copy, modify, and/or distribute this software for any 4 * purpose with or without fee is hereby granted, provided that the above 5 * copyright notice and this permission notice appear in all copies. 6 * 7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ 14 15package kyber 16 17// This code is ported from kyber.c. 18 19import ( 20 "crypto/subtle" 21 "golang.org/x/crypto/sha3" 22 "io" 23) 24 25const( 26 CiphertextSize = 1088 27 PublicKeySize = 1184 28 PrivateKeySize = 2400 29) 30 31const ( 32 degree = 256 33 rank = 3 34 prime = 3329 35 log2Prime = 12 36 halfPrime = (prime - 1) / 2 37 du = 10 38 dv = 4 39 inverseDegree = 3303 40 encodedVectorSize = log2Prime * degree / 8 * rank 41 compressedVectorSize = du * rank * degree / 8 42 barrettMultiplier = 5039 43 barrettShift = 24 44) 45 46func reduceOnce(x uint16) uint16 { 47 if x >= 2*prime { 48 panic("reduce_once: value out of range") 49 } 50 subtracted := x - prime 51 mask := 0 - (subtracted >> 15) 52 return (mask & x) | (^mask & subtracted) 53} 54 55func reduce(x uint32) uint16 { 56 if x >= prime+2*prime*prime { 57 panic("reduce: value out of range") 58 } 59 product := uint64(x) * barrettMultiplier 60 quotient := uint32(product >> barrettShift) 61 remainder := uint32(x) - quotient*prime 62 return reduceOnce(uint16(remainder)) 63} 64 65// lt returns 0xff..f if a < b and 0 otherwise 66func lt(a, b uint32) uint32 { 67 return uint32(0 - int32(a^((a^b)|((a-b)^a)))>>31) 68} 69 70// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping 71// numbers close to each other together. The formula used is 72// round(2^|bits|/prime*x) mod 2^|bits|. 73// Uses Barrett reduction to achieve constant time. Since we need both the 74// remainder (for rounding) and the quotient (as the result), we cannot use 75// |reduce| here, but need to do the Barrett reduction directly. 76func compress(x uint16, bits int) uint16 { 77 product := uint32(x) << bits 78 quotient := uint32((uint64(product) * barrettMultiplier) >> barrettShift) 79 remainder := product - quotient*prime 80 81 // Adjust the quotient to round correctly: 82 // 0 <= remainder <= halfPrime round to 0 83 // halfPrime < remainder <= prime + halfPrime round to 1 84 // prime + halfPrime < remainder < 2 * prime round to 2 85 quotient += 1 & lt(halfPrime, remainder) 86 quotient += 1 & lt(prime+halfPrime, remainder) 87 return uint16(quotient) & ((1 << bits) - 1) 88} 89 90func decompress(x uint16, bits int) uint16 { 91 product := uint32(x) * prime 92 power := uint32(1) << bits 93 // This is |product| % power, since |power| is a power of 2. 94 remainder := product & (power - 1) 95 // This is |product| / power, since |power| is a power of 2. 96 lower := product >> bits 97 // The rounding logic works since the first half of numbers mod |power| have a 98 // 0 as first bit, and the second half has a 1 as first bit, since |power| is 99 // a power of 2. As a 12 bit number, |remainder| is always positive, so we 100 // will shift in 0s for a right shift. 101 return uint16(lower + (remainder >> (bits - 1))) 102} 103 104type scalar [degree]uint16 105 106func (s *scalar) zero() { 107 for i := range s { 108 s[i] = 0 109 } 110} 111 112// This bit of Python will be referenced in some of the following comments: 113// 114// p = 3329 115// 116// def bitreverse(i): 117// ret = 0 118// for n in range(7): 119// bit = i & 1 120// ret <<= 1 121// ret |= bit 122// i >>= 1 123// return ret 124 125// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] 126var nttRoots = [128]uint16{ 127 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 128 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 129 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 130 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 131 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 132 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 133 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 134 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 135 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 136 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 137 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, 138} 139 140func (s *scalar) ntt() { 141 offset := degree 142 for step := 1; step < degree/2; step <<= 1 { 143 offset >>= 1 144 k := 0 145 for i := 0; i < step; i++ { 146 stepRoot := uint32(nttRoots[i+step]) 147 for j := k; j < k+offset; j++ { 148 odd := reduce(stepRoot * uint32(s[j+offset])) 149 even := s[j] 150 s[j] = reduceOnce(odd + even) 151 s[j+offset] = reduceOnce(even - odd + prime) 152 } 153 k += 2 * offset 154 } 155 } 156} 157 158// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] 159var inverseNTTRoots = [128]uint16{ 160 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 161 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, 162 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 163 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, 164 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, 165 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 166 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, 167 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 168 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 169 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, 170 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, 171} 172 173func (s *scalar) inverseNTT() { 174 step := degree / 2 175 for offset := 2; offset < degree; offset <<= 1 { 176 step >>= 1 177 k := 0 178 for i := 0; i < step; i++ { 179 stepRoot := uint32(inverseNTTRoots[i+step]) 180 for j := k; j < k+offset; j++ { 181 odd := s[j+offset] 182 even := s[j] 183 s[j] = reduceOnce(odd + even) 184 s[j+offset] = reduce(stepRoot * uint32(even-odd+prime)) 185 } 186 k += 2 * offset 187 } 188 } 189 for i := range s { 190 s[i] = reduce(uint32(s[i]) * inverseDegree) 191 } 192} 193 194func (s *scalar) add(b *scalar) { 195 for i := range s { 196 s[i] = reduceOnce(s[i] + b[i]) 197 } 198} 199 200func (s *scalar) sub(b *scalar) { 201 for i := range s { 202 s[i] = reduceOnce(s[i] - b[i] + prime) 203 } 204} 205 206// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] 207var modRoots = [128]uint16{ 208 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 209 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 210 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 211 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 212 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 213 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 214 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 215 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 216 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 217 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 218 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, 219} 220 221func (s *scalar) mult(a, b *scalar) { 222 for i := 0; i < degree/2; i++ { 223 realReal := uint32(a[2*i]) * uint32(b[2*i]) 224 imgImg := uint32(a[2*i+1]) * uint32(b[2*i+1]) 225 realImg := uint32(a[2*i]) * uint32(b[2*i+1]) 226 imgReal := uint32(a[2*i+1]) * uint32(b[2*i]) 227 s[2*i] = reduce(realReal + uint32(reduce(imgImg))*uint32(modRoots[i])) 228 s[2*i+1] = reduce(imgReal + realImg) 229 } 230} 231 232func (s *scalar) innerProduct(left, right *vector) { 233 s.zero() 234 var product scalar 235 for i := range left { 236 product.mult(&left[i], &right[i]) 237 s.add(&product) 238 } 239} 240 241func (s *scalar) fromKeccakVartime(keccak io.Reader) { 242 var buf [3]byte 243 for i := 0; i < len(s); { 244 keccak.Read(buf[:]) 245 d1 := uint16(buf[0]) + 256*uint16(buf[1]%16) 246 d2 := uint16(buf[1])/16 + 16*uint16(buf[2]) 247 if d1 < prime { 248 s[i] = d1 249 i++ 250 } 251 if d2 < prime && i < len(s) { 252 s[i] = d2 253 i++ 254 } 255 } 256} 257 258func (s *scalar) centeredBinomialEta2(input *[33]byte) { 259 var entropy [128]byte 260 sha3.ShakeSum256(entropy[:], input[:]) 261 262 for i := 0; i < len(s); i += 2 { 263 b := uint16(entropy[i/2]) 264 265 value := uint16(prime) 266 value += (b & 1) + ((b >> 1) & 1) 267 value -= ((b >> 2) & 1) + ((b >> 3) & 1) 268 s[i] = reduceOnce(value) 269 270 b >>= 4 271 value = prime 272 value += (b & 1) + ((b >> 1) & 1) 273 value -= ((b >> 2) & 1) + ((b >> 3) & 1) 274 s[i+1] = reduceOnce(value) 275 } 276} 277 278var masks = [8]uint16{0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff} 279 280func (s *scalar) encode(out []byte, bits int) []byte { 281 var outByte byte 282 outByteBits := 0 283 284 for i := range s { 285 element := s[i] 286 elementBitsDone := 0 287 288 for elementBitsDone < bits { 289 chunkBits := bits - elementBitsDone 290 outBitsRemaining := 8 - outByteBits 291 if chunkBits >= outBitsRemaining { 292 chunkBits = outBitsRemaining 293 outByte |= byte(element&masks[chunkBits-1]) << outByteBits 294 out[0] = outByte 295 out = out[1:] 296 outByteBits = 0 297 outByte = 0 298 } else { 299 outByte |= byte(element&masks[chunkBits-1]) << outByteBits 300 outByteBits += chunkBits 301 } 302 303 elementBitsDone += chunkBits 304 element >>= chunkBits 305 } 306 } 307 308 if outByteBits > 0 { 309 out[0] = outByte 310 out = out[1:] 311 } 312 313 return out 314} 315 316func (s *scalar) decode(in []byte, bits int) ([]byte, bool) { 317 var inByte byte 318 inByteBitsLeft := 0 319 320 for i := range s { 321 var element uint16 322 elementBitsDone := 0 323 324 for elementBitsDone < bits { 325 if inByteBitsLeft == 0 { 326 inByte = in[0] 327 in = in[1:] 328 inByteBitsLeft = 8 329 } 330 331 chunkBits := bits - elementBitsDone 332 if chunkBits > inByteBitsLeft { 333 chunkBits = inByteBitsLeft 334 } 335 336 element |= (uint16(inByte) & masks[chunkBits-1]) << elementBitsDone 337 inByteBitsLeft -= chunkBits 338 inByte >>= chunkBits 339 340 elementBitsDone += chunkBits 341 } 342 343 if element >= prime { 344 return nil, false 345 } 346 s[i] = element 347 } 348 349 return in, true 350} 351 352func (s *scalar) compress(bits int) { 353 for i := range s { 354 s[i] = compress(s[i], bits) 355 } 356} 357 358func (s *scalar) decompress(bits int) { 359 for i := range s { 360 s[i] = decompress(s[i], bits) 361 } 362} 363 364type vector [rank]scalar 365 366func (v *vector) zero() { 367 for i := range v { 368 v[i].zero() 369 } 370} 371 372func (v *vector) ntt() { 373 for i := range v { 374 v[i].ntt() 375 } 376} 377 378func (v *vector) inverseNTT() { 379 for i := range v { 380 v[i].inverseNTT() 381 } 382} 383 384func (v *vector) add(b *vector) { 385 for i := range v { 386 v[i].add(&b[i]) 387 } 388} 389 390func (out *vector) mult(m *matrix, v *vector) { 391 out.zero() 392 var product scalar 393 for i := 0; i < rank; i++ { 394 for j := 0; j < rank; j++ { 395 product.mult(&m[i][j], &v[j]) 396 out[i].add(&product) 397 } 398 } 399} 400 401func (out *vector) multTranspose(m *matrix, v *vector) { 402 out.zero() 403 var product scalar 404 for i := 0; i < rank; i++ { 405 for j := 0; j < rank; j++ { 406 product.mult(&m[j][i], &v[j]) 407 out[i].add(&product) 408 } 409 } 410} 411 412func (v *vector) generateSecretEta2(counter *byte, seed *[32]byte) { 413 var input [33]byte 414 copy(input[:], seed[:]) 415 for i := range v { 416 input[32] = *counter 417 *counter++ 418 v[i].centeredBinomialEta2(&input) 419 } 420} 421 422func (v *vector) encode(out []byte, bits int) []byte { 423 for i := range v { 424 out = v[i].encode(out, bits) 425 } 426 return out 427} 428 429func (v *vector) decode(out []byte, bits int) ([]byte, bool) { 430 var ok bool 431 for i := range v { 432 out, ok = v[i].decode(out, bits) 433 if !ok { 434 return nil, false 435 } 436 } 437 438 return out, true 439} 440 441func (v *vector) compress(bits int) { 442 for i := range v { 443 v[i].compress(bits) 444 } 445} 446 447func (v *vector) decompress(bits int) { 448 for i := range v { 449 v[i].decompress(bits) 450 } 451} 452 453type matrix [rank][rank]scalar 454 455func (m *matrix) expand(rho *[32]byte) { 456 shake := sha3.NewShake128() 457 458 var input [34]byte 459 copy(input[:], rho[:]) 460 461 for i := 0; i < rank; i++ { 462 for j := 0; j < rank; j++ { 463 input[32] = byte(i) 464 input[33] = byte(j) 465 466 shake.Reset() 467 shake.Write(input[:]) 468 m[i][j].fromKeccakVartime(shake) 469 } 470 } 471} 472 473type PublicKey struct { 474 t vector 475 rho [32]byte 476 publicKeyHash [32]byte 477 m matrix 478} 479 480func UnmarshalPublicKey(data *[PublicKeySize]byte) (*PublicKey, bool) { 481 var ret PublicKey 482 ret.publicKeyHash = sha3.Sum256(data[:]) 483 in, ok := ret.t.decode(data[:], log2Prime) 484 if !ok { 485 return nil, false 486 } 487 copy(ret.rho[:], in) 488 ret.m.expand(&ret.rho) 489 return &ret, true 490} 491 492func (pub *PublicKey) Marshal() *[PublicKeySize]byte { 493 var ret [PublicKeySize]byte 494 out := pub.t.encode(ret[:], log2Prime) 495 copy(out, pub.rho[:]) 496 return &ret 497} 498 499func (pub *PublicKey) encryptCPA(message, entropy *[32]byte) *[CiphertextSize]byte { 500 var counter uint8 501 var secret, error vector 502 secret.generateSecretEta2(&counter, entropy) 503 error.generateSecretEta2(&counter, entropy) 504 secret.ntt() 505 506 var input [33]byte 507 copy(input[:], entropy[:]) 508 input[32] = counter 509 var scalarError scalar 510 scalarError.centeredBinomialEta2(&input) 511 512 var u vector 513 u.mult(&pub.m, &secret) 514 u.inverseNTT() 515 u.add(&error) 516 517 var v scalar 518 v.innerProduct(&pub.t, &secret) 519 v.inverseNTT() 520 v.add(&scalarError) 521 522 out := make([]byte, CiphertextSize) 523 var expandedMessage scalar 524 expandedMessage.decode(message[:], 1) 525 expandedMessage.decompress(1) 526 v.add(&expandedMessage) 527 u.compress(du) 528 it := u.encode(out, du) 529 v.compress(dv) 530 v.encode(it, dv) 531 return (*[CiphertextSize]byte)(out) 532} 533 534func (pub *PublicKey) Encap(outSharedSecret []byte, entropy *[32]byte) *[CiphertextSize]byte { 535 var input [64]byte 536 copy(input[:], entropy[:]) 537 copy(input[32:], pub.publicKeyHash[:]) 538 prekeyAndRandomness := sha3.Sum512(input[:]) 539 ciphertext := pub.encryptCPA(entropy, (*[32]byte)(prekeyAndRandomness[32:])) 540 ciphertextHash := sha3.Sum256(ciphertext[:]) 541 copy(prekeyAndRandomness[32:], ciphertextHash[:]) 542 sha3.ShakeSum256(outSharedSecret, prekeyAndRandomness[:]) 543 return ciphertext 544} 545 546type PrivateKey struct { 547 PublicKey 548 s vector 549 foFailureSecret [32]byte 550} 551 552func NewPrivateKey(entropy *[64]byte) (*PrivateKey, *[PublicKeySize]byte) { 553 hashed := sha3.Sum512(entropy[:32]) 554 rho := (*[32]byte)(hashed[:32]) 555 sigma := (*[32]byte)(hashed[32:]) 556 ret := new(PrivateKey) 557 copy(ret.foFailureSecret[:], entropy[32:]) 558 copy(ret.rho[:], rho[:]) 559 ret.m.expand(rho) 560 counter := uint8(0) 561 ret.s.generateSecretEta2(&counter, sigma) 562 ret.s.ntt() 563 var error vector 564 error.generateSecretEta2(&counter, sigma) 565 error.ntt() 566 ret.t.multTranspose(&ret.m, &ret.s) 567 ret.t.add(&error) 568 569 marshalledPublicKey := ret.PublicKey.Marshal() 570 ret.publicKeyHash = sha3.Sum256(marshalledPublicKey[:]) 571 572 return ret, marshalledPublicKey 573} 574 575func (priv *PrivateKey) decryptCPA(ciphertext *[CiphertextSize]byte) [32]byte { 576 var u vector 577 u.decode(ciphertext[:], du) 578 u.decompress(du) 579 u.ntt() 580 581 var v scalar 582 v.decode(ciphertext[compressedVectorSize:], dv) 583 v.decompress(dv) 584 585 var mask scalar 586 mask.innerProduct(&priv.s, &u) 587 mask.inverseNTT() 588 v.sub(&mask) 589 v.compress(1) 590 var out [32]byte 591 v.encode(out[:], 1) 592 return out 593} 594 595func (priv *PrivateKey) Decap(outSharedSecret []byte, ciphertext *[CiphertextSize]byte) { 596 decrypted := priv.decryptCPA(ciphertext) 597 h := sha3.New512() 598 h.Write(decrypted[:]) 599 h.Write(priv.publicKeyHash[:]) 600 prekeyAndRandomness := h.Sum(nil) 601 expectedCiphertext := priv.encryptCPA(&decrypted, (*[32]byte)(prekeyAndRandomness[32:])) 602 equal := subtle.ConstantTimeCompare(ciphertext[:], expectedCiphertext[:]) 603 var secret [32]byte 604 for i := range secret { 605 secret[i] = byte(subtle.ConstantTimeSelect(equal, int(prekeyAndRandomness[i]), int(priv.foFailureSecret[i]))) 606 } 607 ciphertextHash := sha3.Sum256(ciphertext[:]) 608 609 shake := sha3.NewShake256() 610 shake.Write(secret[:]) 611 shake.Write(ciphertextHash[:]) 612 shake.Read(outSharedSecret) 613} 614 615func (priv *PrivateKey) Marshal() *[PrivateKeySize]byte { 616 var ret [PrivateKeySize]byte 617 out := priv.s.encode(ret[:], log2Prime) 618 publicKey := priv.PublicKey.Marshal() 619 n := copy(out, publicKey[:]) 620 out = out[n:] 621 n = copy(out, priv.publicKeyHash[:]) 622 out = out[n:] 623 copy(out, priv.foFailureSecret[:]) 624 return &ret 625} 626