/* Copyright (c) 2023, Google Inc.
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <openssl/experimental/kyber.h>

#include <assert.h>
#include <stdlib.h>

#include <openssl/bytestring.h>
#include <openssl/rand.h>

#include "../internal.h"
#include "../keccak/internal.h"
#include "./internal.h"


// See
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
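//
// A minimal usage sketch of the public API defined in this file (hypothetical
// caller code; error handling and the transport of |pub_bytes| and
// |ciphertext| between the two parties are omitted):
//
//   uint8_t pub_bytes[KYBER_PUBLIC_KEY_BYTES];
//   struct KYBER_private_key priv;
//   KYBER_generate_key(pub_bytes, &priv);
//
//   struct KYBER_public_key pub;
//   CBS cbs;
//   CBS_init(&cbs, pub_bytes, sizeof(pub_bytes));
//   if (!KYBER_parse_public_key(&pub, &cbs)) {
//     // handle parse failure
//   }
//
//   uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES];
//   uint8_t secret_a[KYBER_SHARED_SECRET_BYTES];
//   uint8_t secret_b[KYBER_SHARED_SECRET_BYTES];
//   KYBER_encap(ciphertext, secret_a, &pub);
//   KYBER_decap(secret_b, ciphertext, &priv);
//   // secret_a and secret_b now hold the same shared secret.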

static void prf(uint8_t *out, size_t out_len, const uint8_t in[33]) {
  BORINGSSL_keccak(out, out_len, in, 33, boringssl_shake256);
}

static void hash_h(uint8_t out[32], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 32, in, len, boringssl_sha3_256);
}

static void hash_g(uint8_t out[64], const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, 64, in, len, boringssl_sha3_512);
}

static void kdf(uint8_t *out, size_t out_len, const uint8_t *in, size_t len) {
  BORINGSSL_keccak(out, out_len, in, len, boringssl_shake256);
}

#define DEGREE 256
#define RANK 3

static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
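// (Indeed, 128 * 3303 = 422784 = 127 * 3329 + 1.)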
static const uint16_t kInverseDegree = 3303;
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;

typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

typedef struct vector {
  scalar v[RANK];
} vector;

typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;

// This bit of Python will be referenced in some of the following comments:
//
// p = 3329
//
// def bitreverse(i):
//     ret = 0
//     for n in range(7):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret

// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
static const uint16_t kNTTRoots[128] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
    2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
    1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
    2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};

// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
static const uint16_t kInverseNTTRoots[128] = {
    1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
    2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
    1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
    2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
    2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
    2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
    2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
    1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
    2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
};

// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
static const uint16_t kModRoots[128] = {
    17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
    756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
    2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
    939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
    268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
    2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
    2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};

// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  uint16_t mask = 0u - (subtracted >> 15);
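  // |mask| is all ones when |subtracted| underflowed (i.e. x < kPrime) and all
  // zeros otherwise, so the bit-select below picks x or x - kPrime.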
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of
  // Kyber overall and Clang still produces constant-time code using `csel`. On
  // other platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}

// reduce constant-time reduces x mod kPrime using Barrett reduction. x must be
// less than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}

static void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

static void vector_zero(vector *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
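// Each butterfly below maps the pair (even, odd) = (s->c[j], s->c[j + offset])
// to (even + root*odd, even - root*odd) mod kPrime.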
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}

static void vector_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_ntt(&a->v[i]);
  }
}

// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of s->c being interpreted as elements of GF(3329^2). Just as with
// the number theoretic transform, this leaves off the first step of the normal
// iFFT to account for the fact that 3329 does not have a 512th root of unity,
// using the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}

static void vector_inverse_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_inverse_ntt(&a->v[i]);
  }
}

static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
  }
}

static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
  }
}

// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of
// GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). The value of
// 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed |kModRoots|
// table. Note that our Barrett reduction only allows us to multiply two
// reduced numbers together, so we need some intermediate reduction steps, even
// if a uint64_t could hold 3 multiplied numbers.
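// Concretely, in GF(3329)[X]/(X^2 - r) we have
//   (a0 + a1*X) * (b0 + b1*X) = (a0*b0 + a1*b1*r) + (a0*b1 + a1*b0)*X,
// which is exactly what the loop below computes, with r = kModRoots[i].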
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}

static void vector_add(vector *lhs, const vector *rhs) {
  for (int i = 0; i < RANK; i++) {
    scalar_add(&lhs->v[i], &rhs->v[i]);
  }
}

static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[i][j], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[j][i], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int i = 0; i < RANK; i++) {
    scalar product;
    scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
    scalar_add(out, &product);
  }
}

// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->squeeze_offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}

// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binomially distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of the first |eta| bits minus the
// count of the second |eta| bits, resulting in a centered binomial
// distribution. Since eta is two this gives -2/2 with a probability of 1/16,
// -1/1 with probability 1/4, and 0 with probability 3/8.
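// For example, the low four bits 0b0110 of an entropy byte give
// (0 + 1) - (1 + 0) = 0: the loop below adds the first two bits and subtracts
// the next two.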
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  prf(entropy, sizeof(entropy), input);

  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}

// Generates a secret vector by using
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
// and appending and incrementing |counter| for each entry of the vector.
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
                                         const uint8_t seed[32]) {
  uint8_t input[33];
  OPENSSL_memcpy(input, seed, 32);
  for (int i = 0; i < RANK; i++) {
    input[32] = (*counter)++;
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
  }
}

// Expands the matrix from a seed, for key generation and for CPA encryption.
static void matrix_expand(matrix *out, const uint8_t rho[32]) {
  uint8_t input[34];
  OPENSSL_memcpy(input, rho, 32);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      input[32] = i;
      input[33] = j;
      struct BORINGSSL_keccak_st keccak_ctx;
      BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
      BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
    }
  }
}

static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};

// scalar_encode packs the |DEGREE| coefficients of |s|, |bits| bits each,
// least-significant bits first, into |out|.
static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}

// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
static void scalar_encode_1(uint8_t out[32], const scalar *s) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t out_byte = 0;
    for (int j = 0; j < 8; j++) {
      out_byte |= (s->c[i + j] & 1) << j;
    }
    *out = out_byte;
    out++;
  }
}

// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
// (DEGREE) is divisible by 8, the individual vector entries will always fill a
// whole number of bytes, so we do not need to worry about bit packing here.
static void vector_encode(uint8_t *out, const vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
  }
}

// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}

// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t in_byte = *in;
    in++;
    for (int j = 0; j < 8; j++) {
      out->c[i + j] = in_byte & 1;
      in_byte >>= 1;
    }
  }
}

// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
// success or zero if any parsed value is >= |kPrime|.
static int vector_decode(vector *out, const uint8_t *in, int bits) {
  for (int i = 0; i < RANK; i++) {
    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
      return 0;
    }
  }
  return 1;
}

// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// nearby values together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
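// For example, with |bits| = 4, compress(1664, 4) computes
// round(16 * 1664 / 3329) = round(7.997...) = 8.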
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  return quotient & ((1 << bits) - 1);
}

// Decompresses |x| by using an equi-distant representative. The formula is
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
// implement this logic using only bit operations.
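// For example, with |bits| = 4, decompress(8, 4) computes
// round(3329 * 8 / 16) = round(1664.5) = 1665, so compress and decompress
// roughly invert one another.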
static uint16_t decompress(uint16_t x, int bits) {
  uint32_t product = (uint32_t)x * kPrime;
  uint32_t power = 1 << bits;
  // This is |product| % power, since |power| is a power of 2.
  uint32_t remainder = product & (power - 1);
  // This is |product| / power, since |power| is a power of 2.
  uint32_t lower = product >> bits;
  // The rounding logic works since the first half of numbers mod |power| have
  // a 0 as first bit, and the second half has a 1 as first bit, since |power|
  // is a power of 2. As a 12 bit number, |remainder| is always positive, so we
  // will shift in 0s for a right shift.
  return lower + (remainder >> (bits - 1));
}

static void scalar_compress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = compress(s->c[i], bits);
  }
}

static void scalar_decompress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = decompress(s->c[i], bits);
  }
}

static void vector_compress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_compress(&a->v[i], bits);
  }
}

static void vector_decompress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_decompress(&a->v[i], bits);
  }
}

struct public_key {
  vector t;
  uint8_t rho[32];
  uint8_t public_key_hash[32];
  matrix m;
};

static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  return (struct public_key *)external;
}

struct private_key {
  struct public_key pub;
  vector s;
  uint8_t fo_failure_secret[32];
};

static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}

// Calls |KYBER_generate_key_external_entropy| with random bytes from
// |RAND_bytes|.
void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
                        struct KYBER_private_key *out_private_key) {
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
  RAND_bytes(entropy, sizeof(entropy));
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
                                      entropy);
}

static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *vector_output;
  if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(vector_output, &pub->t, kLog2Prime);
  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  return 1;
}

// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
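// In NTT form, the public value computed below is t = m^T·s + e, via
// |matrix_mult_transpose| and |vector_add|.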
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  uint8_t hashed[64];
  hash_g(hashed, entropy, 32);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    abort();
  }

  hash_h(priv->pub.public_key_hash, out_encoded_public_key,
         KYBER_PUBLIC_KEY_BYTES);
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}

void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
                               const struct KYBER_private_key *private_key) {
  struct public_key *const pub = public_key_from_external(out_public_key);
  const struct private_key *const priv =
      private_key_from_external(private_key);
  *pub = priv->pub;
}

// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
// to decryption failure oracles.
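// In the code below, u = InvNTT(m·secret) + error and
// v = InvNTT(<t, secret>) + scalar_error + Decompress(message, 1); u is then
// compressed to kDU and v to kDV bits per coefficient before encoding.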
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  uint8_t counter = 0;
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}

// Calls |KYBER_encap_external_entropy| with random bytes from |RAND_bytes|.
void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
                 uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const struct KYBER_public_key *public_key) {
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret, public_key,
                               entropy);
}

// Algorithm 8 of the Kyber spec, save for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
    const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, input, sizeof(input));
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  hash_h(prekey_and_randomness + 32, out_ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, prekey_and_randomness,
      sizeof(prekey_and_randomness));
}

// Algorithm 6 of the Kyber spec.
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
  vector u;
  vector_decode(&u, ciphertext, kDU);
  vector_decompress(&u, kDU);
  vector_ntt(&u);
  scalar v;
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
  scalar_decompress(&v, kDV);
  scalar mask;
  scalar_inner_product(&mask, &priv->s, &u);
  scalar_inverse_ntt(&mask);
  scalar_sub(&v, &mask);
  scalar_compress(&v, 1);
  scalar_encode_1(out, &v);
}

// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void KYBER_decap(uint8_t out_shared_secret[KYBER_SHARED_SECRET_BYTES],
                 const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                 const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  uint8_t prekey_and_randomness[64];
  hash_g(prekey_and_randomness, decrypted, sizeof(decrypted));
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  uint8_t input[64];
  for (int i = 0; i < 32; i++) {
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  hash_h(input + 32, ciphertext, KYBER_CIPHERTEXT_BYTES);
  kdf(out_shared_secret, KYBER_SHARED_SECRET_BYTES, input, sizeof(input));
}

int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  return kyber_marshal_public_key(out, public_key_from_external(public_key));
}

// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
// the value of |pub->public_key_hash|.
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
  CBS t_bytes;
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  matrix_expand(&pub->m, pub->rho);
  return 1;
}

int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);
  CBS orig_in = *in;
  if (!kyber_parse_public_key_no_hash(pub, in) ||  //
      CBS_len(in) != 0) {
    return 0;
  }
  hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
  return 1;
}

int KYBER_marshal_private_key(CBB *out,
                              const struct KYBER_private_key *private_key) {
  const struct private_key *const priv =
      private_key_from_external(private_key);
  uint8_t *s_output;
  if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(s_output, &priv->s, kLog2Prime);
  if (!kyber_marshal_public_key(out, &priv->pub) ||
      !CBB_add_bytes(out, priv->pub.public_key_hash,
                     sizeof(priv->pub.public_key_hash)) ||
      !CBB_add_bytes(out, priv->fo_failure_secret,
                     sizeof(priv->fo_failure_secret))) {
    return 0;
  }
  return 1;
}

int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
                            CBS *in) {
  struct private_key *const priv = private_key_from_external(out_private_key);

  CBS s_bytes;
  if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
      !kyber_parse_public_key_no_hash(&priv->pub, in) ||
      !CBS_copy_bytes(in, priv->pub.public_key_hash,
                      sizeof(priv->pub.public_key_hash)) ||
      !CBS_copy_bytes(in, priv->fo_failure_secret,
                      sizeof(priv->fo_failure_secret)) ||
      CBS_len(in) != 0) {
    return 0;
  }
  return 1;
}