/* Copyright (c) 2023, Google Inc.
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <openssl/experimental/kyber.h>

#include <assert.h>
#include <stdlib.h>

#include <openssl/bytestring.h>
#include <openssl/rand.h>

#include "../internal.h"
#include "../keccak/internal.h"
#include "./internal.h"


// See
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf

static void prf(uint8_t *out, size_t out_len, const uint8_t in[33]) {
  BORINGSSL_keccak(out, out_len, in, 33, boringssl_shake256);
}

static void hash_h(uint8_t out[32], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 32, in, len, boringssl_sha3_256);
}

static void hash_g(uint8_t out[64], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 64, in, len, boringssl_sha3_512);
}

static void kdf(uint8_t *out, size_t out_len, const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, out_len, in, len, boringssl_shake256);
}

#define DEGREE 256
#define RANK 3

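// kBarrettMultiplier is floor(2^kBarrettShift / kPrime), i.e.
// floor(2^24 / 3329) = 5039, the precomputed reciprocal used by the Barrett
// reductions below.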
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
static const uint16_t kInverseDegree = 3303;
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;

typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

typedef struct vector {
  scalar v[RANK];
} vector;

typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;

// This bit of Python will be referenced in some of the following comments:
//
// p = 3329
//
// def bitreverse(i):
//     ret = 0
//     for n in range(7):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret

// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};

// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
static const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};

// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};

// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  uint16_t mask = 0u - (subtracted >> 15);
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of Kyber
  // overall and Clang still produces constant-time code using `csel`. On other
  // platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}

// reduce reduces x mod kPrime, in constant time, using Barrett reduction. x
// must be less than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}
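
// A worked example of the Barrett reduction above, with all arithmetic exact:
// reduce(10000) computes product = 10000 * 5039 = 50390000, quotient =
// 50390000 >> 24 = 3, and remainder = 10000 - 3 * 3329 = 13. Since 13 <
// kPrime, reduce_once leaves it unchanged, and 10000 mod 3329 is indeed 13.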

static void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

static void vector_zero(vector *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
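// Each inner loop iteration below is a standard FFT butterfly on the pair
// (s->c[j], s->c[j + offset]): with root r it computes (even + r*odd,
// even - r*odd) mod kPrime, with kPrime added before the subtraction to keep
// the intermediate value non-negative.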
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}

static void vector_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_ntt(&a->v[i]);
  }
}

// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of |s->c| being interpreted as elements of GF(3329^2). Just as with
// the number theoretic transform, this leaves off the first step of the normal
// iFFT to account for the fact that 3329 does not have a 512th root of unity,
// using the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}

static void vector_inverse_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_inverse_ntt(&a->v[i]);
  }
}

static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
  }
}

static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
  }
}

// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of
// GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). The value of
// 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed |kModRoots|
// table. Note that our Barrett reduction only allows us to multiply two
// reduced numbers together, so we need some intermediate reduction steps, even
// if a uint64_t could hold 3 multiplied numbers.
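// Concretely, for each i the product of the two degree-one polynomials
//   (lhs[2i] + lhs[2i+1]*X) * (rhs[2i] + rhs[2i+1]*X)
// is computed modulo X^2 - kModRoots[i], giving the coefficients
//   real_real + img_img * kModRoots[i]   and   real_img + img_real
// that appear in the code below.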
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}

static void vector_add(vector *lhs, const vector *rhs) {
  for (int i = 0; i < RANK; i++) {
    scalar_add(&lhs->v[i], &rhs->v[i]);
  }
}

static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[i][j], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[j][i], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int i = 0; i < RANK; i++) {
    scalar product;
    scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
    scalar_add(out, &product);
  }
}

// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
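// For example, the three bytes 0x01, 0x02, 0x03 yield the candidates
// d1 = 0x01 + 256*(0x02 % 16) = 513 and d2 = 0x02/16 + 16*0x03 = 48, both of
// which are < kPrime and so are accepted; candidates >= kPrime are discarded
// and more bytes are squeezed from the Keccak stream.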
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->squeeze_offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}

// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binomially distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of the first |eta| bits minus the
// count of the second |eta| bits, resulting in a centered binomial
// distribution. Since eta is two this gives -2 or 2 with probability 1/16
// each, -1 or 1 with probability 1/4 each, and 0 with probability 3/8.
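// For example, the low nibble 0b0110 contributes (0 + 1) - (1 + 0) = 0, while
// 0b0011 contributes (1 + 1) - (0 + 0) = 2; the code below adds kPrime first
// so that the subtraction cannot underflow.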
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  prf(entropy, sizeof(entropy), input);

  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}

// Generates a secret vector by using
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given
// seed, appending |counter|, and incrementing it for each entry of the vector.
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
                                         const uint8_t seed[32]) {
  uint8_t input[33];
  OPENSSL_memcpy(input, seed, 32);
  for (int i = 0; i < RANK; i++) {
    input[32] = (*counter)++;
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
  }
}

// Expands the matrix from a seed, for key generation and for encaps-CPA.
static void matrix_expand(matrix *out, const uint8_t rho[32]) {
  uint8_t input[34];
  OPENSSL_memcpy(input, rho, 32);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      input[32] = i;
      input[33] = j;
      struct BORINGSSL_keccak_st keccak_ctx;
      BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
      BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
    }
  }
}

static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};

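// scalar_encode packs the low |bits| bits of each coefficient into |out|,
// little-endian within each byte. For example, with |bits| == 4 the
// coefficients 0x3 and 0xA are packed into the single byte 0xa3. kMasks[i]
// has its low i+1 bits set.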
static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}

// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
static void scalar_encode_1(uint8_t out[32], const scalar *s) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t out_byte = 0;
    for (int j = 0; j < 8; j++) {
      out_byte |= (s->c[i + j] & 1) << j;
    }
    *out = out_byte;
    out++;
  }
}

// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
// (DEGREE) is divisible by 8, the individual vector entries will always fill a
// whole number of bytes, so we do not need to worry about bit packing here.
static void vector_encode(uint8_t *out, const vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
  }
}

// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}

// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t in_byte = *in;
    in++;
    for (int j = 0; j < 8; j++) {
      out->c[i + j] = in_byte & 1;
      in_byte >>= 1;
    }
  }
}

// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
// success or zero if any parsed value is >= |kPrime|.
static int vector_decode(vector *out, const uint8_t *in, int bits) {
  for (int i = 0; i < RANK; i++) {
    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
      return 0;
    }
  }
  return 1;
}

// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// numbers close to each other together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
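// For example, compress(2000, 10) computes shifted = 2000 << 10 = 2048000,
// quotient = (2048000 * 5039) >> 24 = 615, and remainder = 2048000 -
// 615 * 3329 = 665. Since 665 <= kHalfPrime, the quotient is not adjusted and
// the result is 615, which equals round(2^10/kPrime * 2000) mod 2^10.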
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  return quotient & ((1 << bits) - 1);
}

// Decompresses |x| by using an equi-distant representative. The formula is
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
// implement this logic using only bit operations.
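// Continuing the example from |compress|, decompress(615, 10) computes
// product = 615 * 3329 = 2047335, lower = 2047335 >> 10 = 1999, and
// remainder = 2047335 & 1023 = 359; 359 >> 9 = 0, so the result is 1999,
// close to, but not equal to, the original 2000: the compression is lossy.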
static uint16_t decompress(uint16_t x, int bits) {
  uint32_t product = (uint32_t)x * kPrime;
  uint32_t power = 1 << bits;
  // This is |product| % power, since |power| is a power of 2.
  uint32_t remainder = product & (power - 1);
  // This is |product| / power, since |power| is a power of 2.
  uint32_t lower = product >> bits;
  // The rounding logic works since the first half of numbers mod |power| have
  // a 0 as their top bit, and the second half have a 1 as their top bit, since
  // |power| is a power of 2. As a 12-bit number, |remainder| is always
  // positive, so we will shift in 0s for a right shift.
  return lower + (remainder >> (bits - 1));
}

static void scalar_compress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = compress(s->c[i], bits);
  }
}

static void scalar_decompress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = decompress(s->c[i], bits);
  }
}

static void vector_compress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_compress(&a->v[i], bits);
  }
}

static void vector_decompress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_decompress(&a->v[i], bits);
  }
}

struct public_key {
  vector t;
  uint8_t rho[32];
  uint8_t public_key_hash[32];
  matrix m;
};

static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  return (struct public_key *)external;
}

struct private_key {
  struct public_key pub;
  vector s;
  uint8_t fo_failure_secret[32];
};

static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}

// Calls |KYBER_generate_key_external_entropy| with random bytes from
// |RAND_bytes|.
void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
                        struct KYBER_private_key *out_private_key) {
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
  RAND_bytes(entropy, sizeof(entropy));
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
                                      entropy);
}

static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *vector_output;
  if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(vector_output, &pub->t, kLog2Prime);
  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  return 1;
}

// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  uint8_t hashed[64];
  hash_g(hashed, entropy, 32);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    abort();
  }

  hash_h(priv->pub.public_key_hash, out_encoded_public_key,
         KYBER_PUBLIC_KEY_BYTES);
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}

void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
                               const struct KYBER_private_key *private_key) {
  struct public_key *const pub = public_key_from_external(out_public_key);
  const struct private_key *const priv = private_key_from_external(private_key);
  *pub = priv->pub;
}

// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
// to decryption failure oracles.
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  uint8_t counter = 0;
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}

// Calls |KYBER_encap_external_entropy| with random bytes from |RAND_bytes|.
void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
                 uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const struct KYBER_public_key *public_key) {
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret, public_key,
                               entropy);
}

// Algorithm 8 of the Kyber spec, save for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
    const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, input, sizeof(input));
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  hash_h(prekey_and_randomness + 32, out_ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, prekey_and_randomness,
      sizeof(prekey_and_randomness));
}

// Algorithm 6 of the Kyber spec.
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
  vector u;
  vector_decode(&u, ciphertext, kDU);
  vector_decompress(&u, kDU);
  vector_ntt(&u);
  scalar v;
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
  scalar_decompress(&v, kDV);
  scalar mask;
  scalar_inner_product(&mask, &priv->s, &u);
  scalar_inverse_ntt(&mask);
  scalar_sub(&v, &mask);
  scalar_compress(&v, 1);
  scalar_encode_1(out, &v);
}

// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void KYBER_decap(uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                 const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, decrypted, sizeof(decrypted));
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  uint8_t input[64];
  for (int i = 0; i < 32; i++) {
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  hash_h(input + 32, ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, input, sizeof(input));
}

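// A minimal usage sketch of the public API above (illustrative only, not part
// of the library):
//
//   uint8_t encoded_public_key[KYBER_PUBLIC_KEY_BYTES];
//   struct KYBER_private_key priv;
//   KYBER_generate_key(encoded_public_key, &priv);
//
//   struct KYBER_public_key pub;
//   KYBER_public_from_private(&pub, &priv);
//   uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES];
//   uint8_t secret_a[KYBER_SHARED_SECRET_BYTES];
//   KYBER_encap(ciphertext, secret_a, &pub);
//
//   uint8_t secret_b[KYBER_SHARED_SECRET_BYTES];
//   KYBER_decap(secret_b, ciphertext, &priv);
//   // secret_a and secret_b now hold the same shared secret.
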
int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  return kyber_marshal_public_key(out, public_key_from_external(public_key));
}

// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
// the value of |pub->public_key_hash|.
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
  CBS t_bytes;
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  matrix_expand(&pub->m, pub->rho);
  return 1;
}

int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);
  CBS orig_in = *in;
  if (!kyber_parse_public_key_no_hash(pub, in) ||  //
      CBS_len(in) != 0) {
    return 0;
  }
  hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
  return 1;
}

int KYBER_marshal_private_key(CBB *out,
                              const struct KYBER_private_key *private_key) {
  const struct private_key *const priv = private_key_from_external(private_key);
  uint8_t *s_output;
  if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(s_output, &priv->s, kLog2Prime);
  if (!kyber_marshal_public_key(out, &priv->pub) ||
      !CBB_add_bytes(out, priv->pub.public_key_hash,
                     sizeof(priv->pub.public_key_hash)) ||
      !CBB_add_bytes(out, priv->fo_failure_secret,
                     sizeof(priv->fo_failure_secret))) {
    return 0;
  }
  return 1;
}

int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
                            CBS *in) {
  struct private_key *const priv = private_key_from_external(out_private_key);

  CBS s_bytes;
  if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
      !kyber_parse_public_key_no_hash(&priv->pub, in) ||
      !CBS_copy_bytes(in, priv->pub.public_key_hash,
                      sizeof(priv->pub.public_key_hash)) ||
      !CBS_copy_bytes(in, priv->fo_failure_secret,
                      sizeof(priv->fo_failure_secret)) ||
      CBS_len(in) != 0) {
    return 0;
  }
  return 1;
}