// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package mlkem768 implements the quantum-resistant key encapsulation method
// ML-KEM (formerly known as Kyber).
//
// Only the recommended ML-KEM-768 parameter set is provided.
//
// The version currently implemented is the one specified by [NIST FIPS 203 ipd],
// with the unintentional transposition of the matrix A reverted to match the
// behavior of [Kyber version 3.0]. Future versions of this package might
// introduce backwards incompatible changes to implement changes to FIPS 203.
//
// [Kyber version 3.0]: https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
// [NIST FIPS 203 ipd]: https://doi.org/10.6028/NIST.FIPS.203.ipd
package mlkem768

// This package targets security, correctness, simplicity, readability, and
// reviewability as its primary goals. All critical operations are performed in
// constant time.
//
// Variable and function names, as well as code layout, are selected to
// facilitate reviewing the implementation against the NIST FIPS 203 ipd
// document.
//
// Reviewers unfamiliar with polynomials or linear algebra might find the
// background at https://words.filippo.io/kyber-math/ useful.

import (
	"crypto/rand"
	"crypto/subtle"
	"errors"
	"internal/byteorder"

	"golang.org/x/crypto/sha3"
)

const (
	// ML-KEM global constants: the number of coefficients n of a polynomial
	// and the prime modulus q.
	n = 256
	q = 3329

	// log2q is the number of bits used to encode one reduced field element
	// (see ByteEncode₁₂ below).
	log2q = 12

	// ML-KEM-768 parameters. The code makes assumptions based on these values,
	// they can't be changed blindly.
	k  = 3
	η  = 2
	du = 10
	dv = 4

	// encodingSizeX is the byte size of a ringElement or nttElement encoded
	// by ByteEncode_X (FIPS 203 (DRAFT), Algorithm 4).
	encodingSize12 = n * log2q / 8
	encodingSize10 = n * du / 8
	encodingSize4  = n * dv / 8
	encodingSize1  = n * 1 / 8

	// Internal K-PKE sizes, in bytes: the message m, the encoded secret vector
	// s (dkPKE), and the encoded vector t followed by the 32-byte seed ρ
	// (ekPKE).
	messageSize       = encodingSize1
	decryptionKeySize = k * encodingSize12
	encryptionKeySize = k*encodingSize12 + 32

	// Exported sizes, in bytes. CiphertextSize is a ciphertext (k
	// Compress₁₀-encoded polynomials followed by one Compress₄-encoded
	// polynomial); EncapsulationKeySize is an encapsulation key (ekPKE);
	// DecapsulationKeySize is the extended decapsulation key
	// dkPKE || ek || H(ek) || z; SharedKeySize is a shared key; SeedSize is the
	// "d || z" seed accepted by NewKeyFromSeed.
	CiphertextSize       = k*encodingSize10 + encodingSize4
	EncapsulationKeySize = encryptionKeySize
	DecapsulationKeySize = decryptionKeySize + encryptionKeySize + 32 + 32
	SharedKeySize        = 32
	SeedSize             = 32 + 32
)

// A DecapsulationKey is the secret key used to decapsulate a shared key from a
// ciphertext. It includes various precomputed values.
type DecapsulationKey struct {
	// dk is the extended encoding dkPKE || ek || H(ek) || z, as laid out by
	// kemKeyGen and returned (copied) by Bytes.
	dk [DecapsulationKeySize]byte
	// The parsed forms of ek and dkPKE, which kemDecaps uses directly instead
	// of re-parsing dk on every call.
	encryptionKey
	decryptionKey
}

// Bytes returns the extended encoding of the decapsulation key, according to
// FIPS 203 (DRAFT).
//
// The result is a fresh copy; modifying it does not affect dk.
func (dk *DecapsulationKey) Bytes() []byte {
	var b [DecapsulationKeySize]byte
	copy(b[:], dk.dk[:])
	return b[:]
}

// EncapsulationKey returns the public encapsulation key necessary to produce
// ciphertexts.
//
// The result is a fresh copy of the ek portion of the extended encoding.
func (dk *DecapsulationKey) EncapsulationKey() []byte {
	var b [EncapsulationKeySize]byte
	// ek starts right after dkPKE; copy stops at len(b), so the trailing
	// H(ek) and z are not included.
	copy(b[:], dk.dk[decryptionKeySize:])
	return b[:]
}

// encryptionKey is the parsed and expanded form of a PKE encryption key.
type encryptionKey struct {
	t [k]nttElement     // ByteDecode₁₂(ek[:384k])
	A [k * k]nttElement // A[i*k+j] = sampleNTT(ρ, j, i)
}

// decryptionKey is the parsed and expanded form of a PKE decryption key.
type decryptionKey struct {
	s [k]nttElement // ByteDecode₁₂(dk[:decryptionKeySize])
}

// GenerateKey generates a new decapsulation key, drawing random bytes from
// crypto/rand. The decapsulation key must be kept secret.
func GenerateKey() (*DecapsulationKey, error) {
	// The actual logic is in a separate function to outline this allocation.
	dk := &DecapsulationKey{}
	return generateKey(dk)
}

// generateKey fills dk from two 32-byte values d and z, each read separately
// from crypto/rand. It returns an error only if crypto/rand fails.
func generateKey(dk *DecapsulationKey) (*DecapsulationKey, error) {
	var d [32]byte
	if _, err := rand.Read(d[:]); err != nil {
		return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
	}
	var z [32]byte
	if _, err := rand.Read(z[:]); err != nil {
		return nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
	}
	return kemKeyGen(dk, &d, &z), nil
}

// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte
// seed in the "d || z" form. The seed must be uniformly random.
//
// It returns an error if len(seed) != SeedSize.
func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) {
	// The actual logic is in a separate function to outline this allocation.
	dk := &DecapsulationKey{}
	return newKeyFromSeed(dk, seed)
}

// newKeyFromSeed splits seed into d = seed[:32] and z = seed[32:] and derives
// the key into dk via kemKeyGen.
func newKeyFromSeed(dk *DecapsulationKey, seed []byte) (*DecapsulationKey, error) {
	if len(seed) != SeedSize {
		return nil, errors.New("mlkem768: invalid seed length")
	}
	// d and z alias seed; kemKeyGen only reads them.
	d := (*[32]byte)(seed[:32])
	z := (*[32]byte)(seed[32:])
	return kemKeyGen(dk, d, z), nil
}

// NewKeyFromExtendedEncoding parses a decapsulation key from its FIPS 203
// (DRAFT) extended encoding.
//
// It returns an error if the encoding has the wrong length or if any encoded
// coefficient of s or t is not reduced modulo q.
func NewKeyFromExtendedEncoding(decapsulationKey []byte) (*DecapsulationKey, error) {
	// The actual logic is in a separate function to outline this allocation.
	dk := &DecapsulationKey{}
	return newKeyFromExtendedEncoding(dk, decapsulationKey)
}

// newKeyFromExtendedEncoding stores a copy of dkBytes in dk and parses its
// dkPKE and ekPKE sections. On error dk may be partially filled, but nil is
// returned.
func newKeyFromExtendedEncoding(dk *DecapsulationKey, dkBytes []byte) (*DecapsulationKey, error) {
	if len(dkBytes) != DecapsulationKeySize {
		return nil, errors.New("mlkem768: invalid decapsulation key length")
	}

	// Note that we don't check that H(ek) matches ekPKE, as that's not
	// specified in FIPS 203 (DRAFT). This is one reason to prefer the seed
	// private key format.
	dk.dk = [DecapsulationKeySize]byte(dkBytes)

	dkPKE := dkBytes[:decryptionKeySize]
	if err := parseDK(&dk.decryptionKey, dkPKE); err != nil {
		return nil, err
	}

	ekPKE := dkBytes[decryptionKeySize : decryptionKeySize+encryptionKeySize]
	if err := parseEK(&dk.encryptionKey, ekPKE); err != nil {
		return nil, err
	}

	return dk, nil
}

// kemKeyGen generates a decapsulation key.
//
// It implements ML-KEM.KeyGen according to FIPS 203 (DRAFT), Algorithm 15, and
// K-PKE.KeyGen according to FIPS 203 (DRAFT), Algorithm 12. The two are merged
// to save copies and allocations.
//
// If dk is nil, a new DecapsulationKey is allocated. The result is written into
// dk, which is also returned.
func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey {
	if dk == nil {
		dk = &DecapsulationKey{}
	}

	// (ρ, σ) ← G(d)
	G := sha3.Sum512(d[:])
	ρ, σ := G[:32], G[32:]

	A := &dk.A
	for i := byte(0); i < k; i++ {
		for j := byte(0); j < k; j++ {
			// Note that this is consistent with Kyber round 3, rather than with
			// the initial draft of FIPS 203, because NIST signaled that the
			// change was involuntary and will be reverted.
			A[i*k+j] = sampleNTT(ρ, j, i)
		}
	}

	// s and e are sampled with consecutive PRF counters: 0..k-1 for s, then
	// k..2k-1 for e. The order matters for interoperability.
	var N byte
	s := &dk.s
	for i := range s {
		s[i] = ntt(samplePolyCBD(σ, N))
		N++
	}
	e := make([]nttElement, k)
	for i := range e {
		e[i] = ntt(samplePolyCBD(σ, N))
		N++
	}

	t := &dk.t
	for i := range t { // t = A ◦ s + e
		t[i] = e[i]
		for j := range s {
			t[i] = polyAdd(t[i], nttMul(A[i*k+j], s[j]))
		}
	}

	// dkPKE ← ByteEncode₁₂(s)
	// ekPKE ← ByteEncode₁₂(t) || ρ
	// ek ← ekPKE
	// dk ← dkPKE || ek || H(ek) || z
	//
	// dkB starts as a zero-length view of dk.dk, so the appends below write
	// straight into dk.dk while they stay within its capacity.
	dkB := dk.dk[:0]

	for i := range s {
		dkB = polyByteEncode(dkB, s[i])
	}

	for i := range t {
		dkB = polyByteEncode(dkB, t[i])
	}
	dkB = append(dkB, ρ...)

	// H(ek) is computed over everything after dkPKE written so far, i.e. ek.
	H := sha3.New256()
	H.Write(dkB[decryptionKeySize:])
	dkB = H.Sum(dkB)

	dkB = append(dkB, z[:]...)

	// Guards the assumption that the encoding exactly filled dk.dk; a
	// mismatch would mean dk.dk does not hold the encoding just produced.
	if len(dkB) != len(dk.dk) {
		panic("mlkem768: internal error: invalid decapsulation key size")
	}

	return dk
}

// Encapsulate generates a shared key and an associated ciphertext from an
// encapsulation key, drawing random bytes from crypto/rand.
// If the encapsulation key is not valid, Encapsulate returns an error.
//
// The shared key must be kept secret.
func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
	// The actual logic is in a separate function to outline this allocation.
	var cc [CiphertextSize]byte
	return encapsulate(&cc, encapsulationKey)
}

// encapsulate checks the encapsulation key length, draws the 32-byte message m
// from crypto/rand, and defers to kemEncaps, writing the ciphertext into cc.
func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) {
	if len(encapsulationKey) != EncapsulationKeySize {
		return nil, nil, errors.New("mlkem768: invalid encapsulation key length")
	}
	var m [messageSize]byte
	if _, err := rand.Read(m[:]); err != nil {
		return nil, nil, errors.New("mlkem768: crypto/rand Read failed: " + err.Error())
	}
	return kemEncaps(cc, encapsulationKey, &m)
}

// kemEncaps generates a shared key and an associated ciphertext.
//
// It implements ML-KEM.Encaps according to FIPS 203 (DRAFT), Algorithm 16.
// If cc is nil, a new ciphertext buffer is allocated. An error is returned if
// ek fails parseEK's length or modulus checks.
func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) {
	if cc == nil {
		cc = &[CiphertextSize]byte{}
	}

	// (K, r) ← G(m || H(ek))
	H := sha3.Sum256(ek[:])
	g := sha3.New512()
	g.Write(m[:])
	g.Write(H[:])
	G := g.Sum(nil)
	K, r := G[:SharedKeySize], G[SharedKeySize:]
	var ex encryptionKey
	if err := parseEK(&ex, ek[:]); err != nil {
		return nil, nil, err
	}
	c = pkeEncrypt(cc, &ex, m, r)
	return c, K, nil
}

// parseEK parses an encryption key from its encoded form.
//
// It implements the initial stages of K-PKE.Encrypt according to FIPS 203
// (DRAFT), Algorithm 13: decoding t, which also performs the "Modulus check"
// on the key, and expanding the matrix A from the seed ρ.
293func parseEK(ex *encryptionKey, ekPKE []byte) error { 294 if len(ekPKE) != encryptionKeySize { 295 return errors.New("mlkem768: invalid encryption key length") 296 } 297 298 for i := range ex.t { 299 var err error 300 ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]) 301 if err != nil { 302 return err 303 } 304 ekPKE = ekPKE[encodingSize12:] 305 } 306 ρ := ekPKE 307 308 for i := byte(0); i < k; i++ { 309 for j := byte(0); j < k; j++ { 310 // See the note in pkeKeyGen about the order of the indices being 311 // consistent with Kyber round 3. 312 ex.A[i*k+j] = sampleNTT(ρ, j, i) 313 } 314 } 315 316 return nil 317} 318 319// pkeEncrypt encrypt a plaintext message. 320// 321// It implements K-PKE.Encrypt according to FIPS 203 (DRAFT), Algorithm 13, 322// although the computation of t and AT is done in parseEK. 323func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byte, rnd []byte) []byte { 324 var N byte 325 r, e1 := make([]nttElement, k), make([]ringElement, k) 326 for i := range r { 327 r[i] = ntt(samplePolyCBD(rnd, N)) 328 N++ 329 } 330 for i := range e1 { 331 e1[i] = samplePolyCBD(rnd, N) 332 N++ 333 } 334 e2 := samplePolyCBD(rnd, N) 335 336 u := make([]ringElement, k) // NTT⁻¹(AT ◦ r) + e1 337 for i := range u { 338 u[i] = e1[i] 339 for j := range r { 340 // Note that i and j are inverted, as we need the transposed of A. 341 u[i] = polyAdd(u[i], inverseNTT(nttMul(ex.A[j*k+i], r[j]))) 342 } 343 } 344 345 μ := ringDecodeAndDecompress1(m) 346 347 var vNTT nttElement // t⊺ ◦ r 348 for i := range ex.t { 349 vNTT = polyAdd(vNTT, nttMul(ex.t[i], r[i])) 350 } 351 v := polyAdd(polyAdd(inverseNTT(vNTT), e2), μ) 352 353 c := cc[:0] 354 for _, f := range u { 355 c = ringCompressAndEncode10(c, f) 356 } 357 c = ringCompressAndEncode4(c, v) 358 359 return c 360} 361 362// Decapsulate generates a shared key from a ciphertext and a decapsulation key. 363// If the ciphertext is not valid, Decapsulate returns an error. 
364// 365// The shared key must be kept secret. 366func Decapsulate(dk *DecapsulationKey, ciphertext []byte) (sharedKey []byte, err error) { 367 if len(ciphertext) != CiphertextSize { 368 return nil, errors.New("mlkem768: invalid ciphertext length") 369 } 370 c := (*[CiphertextSize]byte)(ciphertext) 371 return kemDecaps(dk, c), nil 372} 373 374// kemDecaps produces a shared key from a ciphertext. 375// 376// It implements ML-KEM.Decaps according to FIPS 203 (DRAFT), Algorithm 17. 377func kemDecaps(dk *DecapsulationKey, c *[CiphertextSize]byte) (K []byte) { 378 h := dk.dk[decryptionKeySize+encryptionKeySize : decryptionKeySize+encryptionKeySize+32] 379 z := dk.dk[decryptionKeySize+encryptionKeySize+32:] 380 381 m := pkeDecrypt(&dk.decryptionKey, c) 382 g := sha3.New512() 383 g.Write(m[:]) 384 g.Write(h) 385 G := g.Sum(nil) 386 Kprime, r := G[:SharedKeySize], G[SharedKeySize:] 387 J := sha3.NewShake256() 388 J.Write(z) 389 J.Write(c[:]) 390 Kout := make([]byte, SharedKeySize) 391 J.Read(Kout) 392 var cc [CiphertextSize]byte 393 c1 := pkeEncrypt(&cc, &dk.encryptionKey, (*[32]byte)(m), r) 394 395 subtle.ConstantTimeCopy(subtle.ConstantTimeCompare(c[:], c1), Kout, Kprime) 396 return Kout 397} 398 399// parseDK parses a decryption key from its encoded form. 400// 401// It implements the computation of s from K-PKE.Decrypt according to FIPS 203 402// (DRAFT), Algorithm 14. 403func parseDK(dx *decryptionKey, dkPKE []byte) error { 404 if len(dkPKE) != decryptionKeySize { 405 return errors.New("mlkem768: invalid decryption key length") 406 } 407 408 for i := range dx.s { 409 f, err := polyByteDecode[nttElement](dkPKE[:encodingSize12]) 410 if err != nil { 411 return err 412 } 413 dx.s[i] = f 414 dkPKE = dkPKE[encodingSize12:] 415 } 416 417 return nil 418} 419 420// pkeDecrypt decrypts a ciphertext. 421// 422// It implements K-PKE.Decrypt according to FIPS 203 (DRAFT), Algorithm 14, 423// although the computation of s is done in parseDK. 
424func pkeDecrypt(dx *decryptionKey, c *[CiphertextSize]byte) []byte { 425 u := make([]ringElement, k) 426 for i := range u { 427 b := (*[encodingSize10]byte)(c[encodingSize10*i : encodingSize10*(i+1)]) 428 u[i] = ringDecodeAndDecompress10(b) 429 } 430 431 b := (*[encodingSize4]byte)(c[encodingSize10*k:]) 432 v := ringDecodeAndDecompress4(b) 433 434 var mask nttElement // s⊺ ◦ NTT(u) 435 for i := range dx.s { 436 mask = polyAdd(mask, nttMul(dx.s[i], ntt(u[i]))) 437 } 438 w := polySub(v, inverseNTT(mask)) 439 440 return ringCompressAndEncode1(nil, w) 441} 442 443// fieldElement is an integer modulo q, an element of ℤ_q. It is always reduced. 444type fieldElement uint16 445 446// fieldCheckReduced checks that a value a is < q. 447func fieldCheckReduced(a uint16) (fieldElement, error) { 448 if a >= q { 449 return 0, errors.New("unreduced field element") 450 } 451 return fieldElement(a), nil 452} 453 454// fieldReduceOnce reduces a value a < 2q. 455func fieldReduceOnce(a uint16) fieldElement { 456 x := a - q 457 // If x underflowed, then x >= 2¹⁶ - q > 2¹⁵, so the top bit is set. 458 x += (x >> 15) * q 459 return fieldElement(x) 460} 461 462func fieldAdd(a, b fieldElement) fieldElement { 463 x := uint16(a + b) 464 return fieldReduceOnce(x) 465} 466 467func fieldSub(a, b fieldElement) fieldElement { 468 x := uint16(a - b + q) 469 return fieldReduceOnce(x) 470} 471 472const ( 473 barrettMultiplier = 5039 // 2¹² * 2¹² / q 474 barrettShift = 24 // log₂(2¹² * 2¹²) 475) 476 477// fieldReduce reduces a value a < 2q² using Barrett reduction, to avoid 478// potentially variable-time division. 479func fieldReduce(a uint32) fieldElement { 480 quotient := uint32((uint64(a) * barrettMultiplier) >> barrettShift) 481 return fieldReduceOnce(uint16(a - quotient*q)) 482} 483 484func fieldMul(a, b fieldElement) fieldElement { 485 x := uint32(a) * uint32(b) 486 return fieldReduce(x) 487} 488 489// fieldMulSub returns a * (b - c). 
This operation is fused to save a 490// fieldReduceOnce after the subtraction. 491func fieldMulSub(a, b, c fieldElement) fieldElement { 492 x := uint32(a) * uint32(b-c+q) 493 return fieldReduce(x) 494} 495 496// fieldAddMul returns a * b + c * d. This operation is fused to save a 497// fieldReduceOnce and a fieldReduce. 498func fieldAddMul(a, b, c, d fieldElement) fieldElement { 499 x := uint32(a) * uint32(b) 500 x += uint32(c) * uint32(d) 501 return fieldReduce(x) 502} 503 504// compress maps a field element uniformly to the range 0 to 2ᵈ-1, according to 505// FIPS 203 (DRAFT), Definition 4.5. 506func compress(x fieldElement, d uint8) uint16 { 507 // We want to compute (x * 2ᵈ) / q, rounded to nearest integer, with 1/2 508 // rounding up (see FIPS 203 (DRAFT), Section 2.3). 509 510 // Barrett reduction produces a quotient and a remainder in the range [0, 2q), 511 // such that dividend = quotient * q + remainder. 512 dividend := uint32(x) << d // x * 2ᵈ 513 quotient := uint32(uint64(dividend) * barrettMultiplier >> barrettShift) 514 remainder := dividend - quotient*q 515 516 // Since the remainder is in the range [0, 2q), not [0, q), we need to 517 // portion it into three spans for rounding. 518 // 519 // [ 0, q/2 ) -> round to 0 520 // [ q/2, q + q/2 ) -> round to 1 521 // [ q + q/2, 2q ) -> round to 2 522 // 523 // We can convert that to the following logic: add 1 if remainder > q/2, 524 // then add 1 again if remainder > q + q/2. 525 // 526 // Note that if remainder > x, then ⌊x⌋ - remainder underflows, and the top 527 // bit of the difference will be set. 528 quotient += (q/2 - remainder) >> 31 & 1 529 quotient += (q + q/2 - remainder) >> 31 & 1 530 531 // quotient might have overflowed at this point, so reduce it by masking. 532 var mask uint32 = (1 << d) - 1 533 return uint16(quotient & mask) 534} 535 536// decompress maps a number x between 0 and 2ᵈ-1 uniformly to the full range of 537// field elements, according to FIPS 203 (DRAFT), Definition 4.6. 
538func decompress(y uint16, d uint8) fieldElement { 539 // We want to compute (y * q) / 2ᵈ, rounded to nearest integer, with 1/2 540 // rounding up (see FIPS 203 (DRAFT), Section 2.3). 541 542 dividend := uint32(y) * q 543 quotient := dividend >> d // (y * q) / 2ᵈ 544 545 // The d'th least-significant bit of the dividend (the most significant bit 546 // of the remainder) is 1 for the top half of the values that divide to the 547 // same quotient, which are the ones that round up. 548 quotient += dividend >> (d - 1) & 1 549 550 // quotient is at most (2¹¹-1) * q / 2¹¹ + 1 = 3328, so it didn't overflow. 551 return fieldElement(quotient) 552} 553 554// ringElement is a polynomial, an element of R_q, represented as an array 555// according to FIPS 203 (DRAFT), Section 2.4. 556type ringElement [n]fieldElement 557 558// polyAdd adds two ringElements or nttElements. 559func polyAdd[T ~[n]fieldElement](a, b T) (s T) { 560 for i := range s { 561 s[i] = fieldAdd(a[i], b[i]) 562 } 563 return s 564} 565 566// polySub subtracts two ringElements or nttElements. 567func polySub[T ~[n]fieldElement](a, b T) (s T) { 568 for i := range s { 569 s[i] = fieldSub(a[i], b[i]) 570 } 571 return s 572} 573 574// polyByteEncode appends the 384-byte encoding of f to b. 575// 576// It implements ByteEncode₁₂, according to FIPS 203 (DRAFT), Algorithm 4. 577func polyByteEncode[T ~[n]fieldElement](b []byte, f T) []byte { 578 out, B := sliceForAppend(b, encodingSize12) 579 for i := 0; i < n; i += 2 { 580 x := uint32(f[i]) | uint32(f[i+1])<<12 581 B[0] = uint8(x) 582 B[1] = uint8(x >> 8) 583 B[2] = uint8(x >> 16) 584 B = B[3:] 585 } 586 return out 587} 588 589// polyByteDecode decodes the 384-byte encoding of a polynomial, checking that 590// all the coefficients are properly reduced. This achieves the "Modulus check" 591// step of ML-KEM Encapsulation Input Validation. 
592// 593// polyByteDecode is also used in ML-KEM Decapsulation, where the input 594// validation is not required, but implicitly allowed by the specification. 595// 596// It implements ByteDecode₁₂, according to FIPS 203 (DRAFT), Algorithm 5. 597func polyByteDecode[T ~[n]fieldElement](b []byte) (T, error) { 598 if len(b) != encodingSize12 { 599 return T{}, errors.New("mlkem768: invalid encoding length") 600 } 601 var f T 602 for i := 0; i < n; i += 2 { 603 d := uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 604 const mask12 = 0b1111_1111_1111 605 var err error 606 if f[i], err = fieldCheckReduced(uint16(d & mask12)); err != nil { 607 return T{}, errors.New("mlkem768: invalid polynomial encoding") 608 } 609 if f[i+1], err = fieldCheckReduced(uint16(d >> 12)); err != nil { 610 return T{}, errors.New("mlkem768: invalid polynomial encoding") 611 } 612 b = b[3:] 613 } 614 return f, nil 615} 616 617// sliceForAppend takes a slice and a requested number of bytes. It returns a 618// slice with the contents of the given slice followed by that many bytes and a 619// second slice that aliases into it and contains only the extra bytes. If the 620// original slice has sufficient capacity then no allocation is performed. 621func sliceForAppend(in []byte, n int) (head, tail []byte) { 622 if total := len(in) + n; cap(in) >= total { 623 head = in[:total] 624 } else { 625 head = make([]byte, total) 626 copy(head, in) 627 } 628 tail = head[len(in):] 629 return 630} 631 632// ringCompressAndEncode1 appends a 32-byte encoding of a ring element to s, 633// compressing one coefficients per bit. 634// 635// It implements Compress₁, according to FIPS 203 (DRAFT), Definition 4.5, 636// followed by ByteEncode₁, according to FIPS 203 (DRAFT), Algorithm 4. 
637func ringCompressAndEncode1(s []byte, f ringElement) []byte { 638 s, b := sliceForAppend(s, encodingSize1) 639 for i := range b { 640 b[i] = 0 641 } 642 for i := range f { 643 b[i/8] |= uint8(compress(f[i], 1) << (i % 8)) 644 } 645 return s 646} 647 648// ringDecodeAndDecompress1 decodes a 32-byte slice to a ring element where each 649// bit is mapped to 0 or ⌈q/2⌋. 650// 651// It implements ByteDecode₁, according to FIPS 203 (DRAFT), Algorithm 5, 652// followed by Decompress₁, according to FIPS 203 (DRAFT), Definition 4.6. 653func ringDecodeAndDecompress1(b *[encodingSize1]byte) ringElement { 654 var f ringElement 655 for i := range f { 656 b_i := b[i/8] >> (i % 8) & 1 657 const halfQ = (q + 1) / 2 // ⌈q/2⌋, rounded up per FIPS 203 (DRAFT), Section 2.3 658 f[i] = fieldElement(b_i) * halfQ // 0 decompresses to 0, and 1 to ⌈q/2⌋ 659 } 660 return f 661} 662 663// ringCompressAndEncode4 appends a 128-byte encoding of a ring element to s, 664// compressing two coefficients per byte. 665// 666// It implements Compress₄, according to FIPS 203 (DRAFT), Definition 4.5, 667// followed by ByteEncode₄, according to FIPS 203 (DRAFT), Algorithm 4. 668func ringCompressAndEncode4(s []byte, f ringElement) []byte { 669 s, b := sliceForAppend(s, encodingSize4) 670 for i := 0; i < n; i += 2 { 671 b[i/2] = uint8(compress(f[i], 4) | compress(f[i+1], 4)<<4) 672 } 673 return s 674} 675 676// ringDecodeAndDecompress4 decodes a 128-byte encoding of a ring element where 677// each four bits are mapped to an equidistant distribution. 678// 679// It implements ByteDecode₄, according to FIPS 203 (DRAFT), Algorithm 5, 680// followed by Decompress₄, according to FIPS 203 (DRAFT), Definition 4.6. 
681func ringDecodeAndDecompress4(b *[encodingSize4]byte) ringElement { 682 var f ringElement 683 for i := 0; i < n; i += 2 { 684 f[i] = fieldElement(decompress(uint16(b[i/2]&0b1111), 4)) 685 f[i+1] = fieldElement(decompress(uint16(b[i/2]>>4), 4)) 686 } 687 return f 688} 689 690// ringCompressAndEncode10 appends a 320-byte encoding of a ring element to s, 691// compressing four coefficients per five bytes. 692// 693// It implements Compress₁₀, according to FIPS 203 (DRAFT), Definition 4.5, 694// followed by ByteEncode₁₀, according to FIPS 203 (DRAFT), Algorithm 4. 695func ringCompressAndEncode10(s []byte, f ringElement) []byte { 696 s, b := sliceForAppend(s, encodingSize10) 697 for i := 0; i < n; i += 4 { 698 var x uint64 699 x |= uint64(compress(f[i+0], 10)) 700 x |= uint64(compress(f[i+1], 10)) << 10 701 x |= uint64(compress(f[i+2], 10)) << 20 702 x |= uint64(compress(f[i+3], 10)) << 30 703 b[0] = uint8(x) 704 b[1] = uint8(x >> 8) 705 b[2] = uint8(x >> 16) 706 b[3] = uint8(x >> 24) 707 b[4] = uint8(x >> 32) 708 b = b[5:] 709 } 710 return s 711} 712 713// ringDecodeAndDecompress10 decodes a 320-byte encoding of a ring element where 714// each ten bits are mapped to an equidistant distribution. 715// 716// It implements ByteDecode₁₀, according to FIPS 203 (DRAFT), Algorithm 5, 717// followed by Decompress₁₀, according to FIPS 203 (DRAFT), Definition 4.6. 
718func ringDecodeAndDecompress10(bb *[encodingSize10]byte) ringElement { 719 b := bb[:] 720 var f ringElement 721 for i := 0; i < n; i += 4 { 722 x := uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 723 b = b[5:] 724 f[i] = fieldElement(decompress(uint16(x>>0&0b11_1111_1111), 10)) 725 f[i+1] = fieldElement(decompress(uint16(x>>10&0b11_1111_1111), 10)) 726 f[i+2] = fieldElement(decompress(uint16(x>>20&0b11_1111_1111), 10)) 727 f[i+3] = fieldElement(decompress(uint16(x>>30&0b11_1111_1111), 10)) 728 } 729 return f 730} 731 732// samplePolyCBD draws a ringElement from the special Dη distribution given a 733// stream of random bytes generated by the PRF function, according to FIPS 203 734// (DRAFT), Algorithm 7 and Definition 4.1. 735func samplePolyCBD(s []byte, b byte) ringElement { 736 prf := sha3.NewShake256() 737 prf.Write(s) 738 prf.Write([]byte{b}) 739 B := make([]byte, 128) 740 prf.Read(B) 741 742 // SamplePolyCBD simply draws four (2η) bits for each coefficient, and adds 743 // the first two and subtracts the last two. 744 745 var f ringElement 746 for i := 0; i < n; i += 2 { 747 b := B[i/2] 748 b_7, b_6, b_5, b_4 := b>>7, b>>6&1, b>>5&1, b>>4&1 749 b_3, b_2, b_1, b_0 := b>>3&1, b>>2&1, b>>1&1, b&1 750 f[i] = fieldSub(fieldElement(b_0+b_1), fieldElement(b_2+b_3)) 751 f[i+1] = fieldSub(fieldElement(b_4+b_5), fieldElement(b_6+b_7)) 752 } 753 return f 754} 755 756// nttElement is an NTT representation, an element of T_q, represented as an 757// array according to FIPS 203 (DRAFT), Section 2.4. 758type nttElement [n]fieldElement 759 760// gammas are the values ζ^2BitRev7(i)+1 mod q for each index i. 
761var gammas = [128]fieldElement{17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175} 762 763// nttMul multiplies two nttElements. 764// 765// It implements MultiplyNTTs, according to FIPS 203 (DRAFT), Algorithm 10. 766func nttMul(f, g nttElement) nttElement { 767 var h nttElement 768 // We use i += 2 for bounds check elimination. See https://go.dev/issue/66826. 769 for i := 0; i < 256; i += 2 { 770 a0, a1 := f[i], f[i+1] 771 b0, b1 := g[i], g[i+1] 772 h[i] = fieldAddMul(a0, b0, fieldMul(a1, b1), gammas[i/2]) 773 h[i+1] = fieldAddMul(a0, b1, a1, b0) 774 } 775 return h 776} 777 778// zetas are the values ζ^BitRev7(k) mod q for each index k. 
779var zetas = [128]fieldElement{1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154} 780 781// ntt maps a ringElement to its nttElement representation. 782// 783// It implements NTT, according to FIPS 203 (DRAFT), Algorithm 8. 784func ntt(f ringElement) nttElement { 785 k := 1 786 for len := 128; len >= 2; len /= 2 { 787 for start := 0; start < 256; start += 2 * len { 788 zeta := zetas[k] 789 k++ 790 // Bounds check elimination hint. 791 f, flen := f[start:start+len], f[start+len:start+len+len] 792 for j := 0; j < len; j++ { 793 t := fieldMul(zeta, flen[j]) 794 flen[j] = fieldSub(f[j], t) 795 f[j] = fieldAdd(f[j], t) 796 } 797 } 798 } 799 return nttElement(f) 800} 801 802// inverseNTT maps a nttElement back to the ringElement it represents. 803// 804// It implements NTT⁻¹, according to FIPS 203 (DRAFT), Algorithm 9. 805func inverseNTT(f nttElement) ringElement { 806 k := 127 807 for len := 2; len <= 128; len *= 2 { 808 for start := 0; start < 256; start += 2 * len { 809 zeta := zetas[k] 810 k-- 811 // Bounds check elimination hint. 
812 f, flen := f[start:start+len], f[start+len:start+len+len] 813 for j := 0; j < len; j++ { 814 t := f[j] 815 f[j] = fieldAdd(t, flen[j]) 816 flen[j] = fieldMulSub(zeta, flen[j], t) 817 } 818 } 819 } 820 for i := range f { 821 f[i] = fieldMul(f[i], 3303) // 3303 = 128⁻¹ mod q 822 } 823 return ringElement(f) 824} 825 826// sampleNTT draws a uniformly random nttElement from a stream of uniformly 827// random bytes generated by the XOF function, according to FIPS 203 (DRAFT), 828// Algorithm 6 and Definition 4.2. 829func sampleNTT(rho []byte, ii, jj byte) nttElement { 830 B := sha3.NewShake128() 831 B.Write(rho) 832 B.Write([]byte{ii, jj}) 833 834 // SampleNTT essentially draws 12 bits at a time from r, interprets them in 835 // little-endian, and rejects values higher than q, until it drew 256 836 // values. (The rejection rate is approximately 19%.) 837 // 838 // To do this from a bytes stream, it draws three bytes at a time, and 839 // splits them into two uint16 appropriately masked. 840 // 841 // r₀ r₁ r₂ 842 // |- - - - - - - -|- - - - - - - -|- - - - - - - -| 843 // 844 // Uint16(r₀ || r₁) 845 // |- - - - - - - - - - - - - - - -| 846 // |- - - - - - - - - - - -| 847 // d₁ 848 // 849 // Uint16(r₁ || r₂) 850 // |- - - - - - - - - - - - - - - -| 851 // |- - - - - - - - - - - -| 852 // d₂ 853 // 854 // Note that in little-endian, the rightmost bits are the most significant 855 // bits (dropped with a mask) and the leftmost bits are the least 856 // significant bits (dropped with a right shift). 
857 858 var a nttElement 859 var j int // index into a 860 var buf [24]byte // buffered reads from B 861 off := len(buf) // index into buf, starts in a "buffer fully consumed" state 862 for { 863 if off >= len(buf) { 864 B.Read(buf[:]) 865 off = 0 866 } 867 d1 := byteorder.LeUint16(buf[off:]) & 0b1111_1111_1111 868 d2 := byteorder.LeUint16(buf[off+1:]) >> 4 869 off += 3 870 if d1 < q { 871 a[j] = fieldElement(d1) 872 j++ 873 } 874 if j >= len(a) { 875 break 876 } 877 if d2 < q { 878 a[j] = fieldElement(d2) 879 j++ 880 } 881 if j >= len(a) { 882 break 883 } 884 } 885 return a 886} 887