1 /* Copyright (c) 2023, Google LLC
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15 #define OPENSSL_UNSTABLE_EXPERIMENTAL_DILITHIUM
16 #include <openssl/experimental/dilithium.h>
17
18 #include <assert.h>
19 #include <stdlib.h>
20
21 #include <openssl/bytestring.h>
22 #include <openssl/rand.h>
23
24 #include "../internal.h"
25 #include "../keccak/internal.h"
26 #include "./internal.h"
27
// Parameter set used throughout this file.

// Number of coefficients in one |scalar| (polynomial).
#define DEGREE 256
// Number of entries in a |vectork| and number of rows of a |matrix|.
#define K 6
// Number of entries in a |vectorl| and number of columns of a |matrix|.
#define L 5
// Bound for the short coefficients sampled by |scalar_uniform_eta_4|.
#define ETA 4
// Number of coefficients set to +/-1 by |scalar_sample_in_ball_vartime|.
#define TAU 49
// NOTE(review): not referenced in this part of the file; numerically equal to
// TAU * ETA -- confirm intended meaning against its users.
#define BETA 196
// Number of index bytes in the packed hint; |hint_bit_pack| writes OMEGA + K
// bytes in total.
#define OMEGA 55

// Byte lengths of seeds and digests.
// Seed for |matrix_expand|.
#define RHO_BYTES 32
// Seed for |vector_expand_short|.
#define SIGMA_BYTES 64
#define K_BYTES 32
#define TR_BYTES 64
#define MU_BYTES 64
// Seed for |vectorl_expand_mask|.
#define RHO_PRIME_BYTES 64
#define LAMBDA_BITS 192
#define LAMBDA_BYTES (LAMBDA_BITS / 8)
44
// The modulus: 2^23 - 2^13 + 1.
static const uint32_t kPrime = 8380417;
// Inverse of -kPrime modulo 2^32; the multiplier used by |reduce_montgomery|.
static const uint32_t kPrimeNegInverse = 4236238847;
// Number of low-order bits split off by |power2_round|.
static const int kDroppedBits = 13;
// (kPrime - 1) / 2. Residues above this are treated as negative by
// |abs_mod_prime| and |decompose|.
static const uint32_t kHalfPrime = (8380417 - 1) / 2;
// 2^19. NOTE(review): not referenced by name in this part of the file;
// |scalar_sample_mask| passes the same value as a literal.
static const uint32_t kGamma1 = 1 << 19;
// (kPrime - 1) / 32 == 2^18 - 2^8. |decompose| splits values into multiples of
// 2 * kGamma2 plus a remainder.
static const uint32_t kGamma2 = (8380417 - 1) / 32;
// 256^-1 mod kPrime, in Montgomery form.
static const uint32_t kInverseDegreeMontgomery = 41978;
55
// A polynomial with DEGREE coefficients. Most helpers in this file expect each
// coefficient to be reduced, i.e. in [0, kPrime).
typedef struct scalar {
  uint32_t c[DEGREE];
} scalar;
59
// A vector of K scalars.
typedef struct vectork {
  scalar v[K];
} vectork;
63
// A vector of L scalars.
typedef struct vectorl {
  scalar v[L];
} vectorl;
67
// A K-by-L matrix of scalars, indexed as v[row][column].
typedef struct matrix {
  scalar v[K][L];
} matrix;
71
72 /* Arithmetic */
73
74 // This bit of Python will be referenced in some of the following comments:
75 //
76 // q = 8380417
77 // # Inverse of -q modulo 2^32
78 // q_neg_inverse = 4236238847
79 // # 2^64 modulo q
80 // montgomery_square = 2365951
81 //
82 // def bitreverse(i):
83 // ret = 0
84 // for n in range(8):
85 // bit = i & 1
86 // ret <<= 1
87 // ret |= bit
88 // i >>= 1
89 // return ret
90 //
91 // def montgomery_reduce(x):
92 // a = (x * q_neg_inverse) % 2**32
93 // b = x + a * q
94 // assert b & 0xFFFF_FFFF == 0
95 // c = b >> 32
96 // assert c < q
97 // return c
98 //
99 // def montgomery_transform(x):
100 // return montgomery_reduce(x * montgomery_square)
101
// Powers of 1753 modulo q, stored in bit-reversed order and in Montgomery
// form. Generated with the Python helpers above:
//
// kNTTRootsMontgomery = [
//   montgomery_transform(pow(1753, bitreverse(i), q)) for i in range(256)
// ]
//
// |scalar_ntt| reads entry [step + i]; |scalar_inverse_ntt| reads the same
// range in reverse and negates each entry modulo kPrime.
static const uint32_t kNTTRootsMontgomery[256] = {
    4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468,
    1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
    2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868,
    6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005,
    2706023, 95776, 3077325, 3530437, 6718724, 4788269, 5842901, 3915439,
    4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118,
    6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596,
    811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892, 5582638,
    4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196,
    7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922,
    3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370,
    7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987,
    5037034, 264944, 508951, 3097992, 44288, 7280319, 904516, 3958618,
    4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561,
    189548, 4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330,
    1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961,
    2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955,
    266997, 2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039,
    900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917,
    7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579,
    342297, 286988, 5942594, 4108315, 3437287, 5038140, 1735879, 203044,
    2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974,
    4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447,
    7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775,
    7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031,
    7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136, 4603424,
    6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032,
    5196991, 162844, 1616392, 3014001, 810149, 1652634, 4686184, 6581310,
    5341501, 3523897, 3866901, 269760, 2213111, 7404533, 1717735, 472078,
    7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524,
    5441381, 6144432, 7959518, 6094090, 183443, 7403526, 1612842, 4834730,
    7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782};
138
139 // Reduces x mod kPrime in constant time, where 0 <= x < 2*kPrime.
reduce_once(uint32_t x)140 static uint32_t reduce_once(uint32_t x) {
141 declassify_assert(x < 2 * kPrime);
142 // return x < kPrime ? x : x - kPrime;
143 return constant_time_select_int(constant_time_lt_w(x, kPrime), x, x - kPrime);
144 }
145
146 // Returns the absolute value in constant time.
abs_signed(uint32_t x)147 static uint32_t abs_signed(uint32_t x) {
148 // return is_positive(x) ? x : -x;
149 // Note: MSVC doesn't like applying the unary minus operator to unsigned types
150 // (warning C4146), so we write the negation as a bitwise not plus one
151 // (assuming two's complement representation).
152 return constant_time_select_int(constant_time_lt_w(x, 0x80000000), x, ~x + 1);
153 }
154
155 // Returns the absolute value modulo kPrime.
abs_mod_prime(uint32_t x)156 static uint32_t abs_mod_prime(uint32_t x) {
157 declassify_assert(x < kPrime);
158 // return x > kHalfPrime ? kPrime - x : x;
159 return constant_time_select_int(constant_time_lt_w(kHalfPrime, x), kPrime - x,
160 x);
161 }
162
163 // Returns the maximum of two values in constant time.
maximum(uint32_t x,uint32_t y)164 static uint32_t maximum(uint32_t x, uint32_t y) {
165 // return x < y ? y : x;
166 return constant_time_select_int(constant_time_lt_w(x, y), y, x);
167 }
168
// Coefficient-wise addition modulo kPrime: out = lhs + rhs. Each output
// coefficient is read from the inputs before it is written.
static void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint32_t sum = lhs->c[idx] + rhs->c[idx];
    out->c[idx] = reduce_once(sum);
  }
}
174
// Coefficient-wise subtraction modulo kPrime: out = lhs - rhs.
static void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx < DEGREE; idx++) {
    // Add kPrime first so the unsigned difference does not wrap.
    const uint32_t diff = kPrime + lhs->c[idx] - rhs->c[idx];
    out->c[idx] = reduce_once(diff);
  }
}
180
reduce_montgomery(uint64_t x)181 static uint32_t reduce_montgomery(uint64_t x) {
182 uint64_t a = (uint32_t)x * kPrimeNegInverse;
183 uint64_t b = x + a * kPrime;
184 declassify_assert((b & 0xffffffff) == 0);
185 uint32_t c = b >> 32;
186 return reduce_once(c);
187 }
188
189 // Multiply two scalars in the number theoretically transformed state.
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)190 static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
191 for (int i = 0; i < DEGREE; i++) {
192 out->c[i] = reduce_montgomery((uint64_t)lhs->c[i] * (uint64_t)rhs->c[i]);
193 }
194 }
195
// In place number theoretic transform of a given scalar.
//
// FIPS 204, Algorithm 35 (`NTT`).
//
// Runs log2(DEGREE) passes. In each pass the coefficients are handled in
// |step| groups of 2 * |offset| entries. Within a group, each entry of the
// upper half is multiplied by the group's root from |kNTTRootsMontgomery| and
// then combined with the matching lower-half entry as
// (even + odd, even - odd) modulo kPrime.
static void scalar_ntt(scalar *s) {
  // Step: 1, 2, 4, 8, ..., 128
  // Offset: 128, 64, 32, 16, ..., 1
  int offset = DEGREE;
  for (int step = 1; step < DEGREE; step <<= 1) {
    offset >>= 1;
    // |k| is the index of the first coefficient of the current group.
    int k = 0;
    for (int i = 0; i < step; i++) {
      assert(k == 2 * offset * i);
      const uint32_t step_root = kNTTRootsMontgomery[step + i];
      for (int j = k; j < k + offset; j++) {
        uint32_t even = s->c[j];
        uint32_t odd =
            reduce_montgomery((uint64_t)step_root * (uint64_t)s->c[j + offset]);
        s->c[j] = reduce_once(odd + even);
        // kPrime is added first so the unsigned difference does not wrap.
        s->c[j + offset] = reduce_once(kPrime + even - odd);
      }
      k += 2 * offset;
    }
  }
}
220
// In place inverse number theoretic transform of a given scalar.
//
// FIPS 204, Algorithm 36 (`NTT^-1`).
//
// Runs the passes of |scalar_ntt| in the opposite order. Each pair is replaced
// by (even + odd, root * (even - odd)), where the root is the negation modulo
// kPrime of a |kNTTRootsMontgomery| entry, read in reverse order within the
// pass. Finally every coefficient is multiplied by
// |kInverseDegreeMontgomery|.
static void scalar_inverse_ntt(scalar *s) {
  // Step: 128, 64, 32, 16, ..., 1
  // Offset: 1, 2, 4, 8, ..., 128
  int step = DEGREE;
  for (int offset = 1; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    // |k| is the index of the first coefficient of the current group.
    int k = 0;
    for (int i = 0; i < step; i++) {
      assert(k == 2 * offset * i);
      const uint32_t step_root =
          kPrime - kNTTRootsMontgomery[step + (step - 1 - i)];
      for (int j = k; j < k + offset; j++) {
        uint32_t even = s->c[j];
        uint32_t odd = s->c[j + offset];
        s->c[j] = reduce_once(odd + even);
        // The difference is offset by kPrime to avoid unsigned wrap and is
        // passed to |reduce_montgomery| without a separate |reduce_once|.
        s->c[j + offset] = reduce_montgomery((uint64_t)step_root *
                                             (uint64_t)(kPrime + even - odd));
      }
      k += 2 * offset;
    }
  }
  // Scale by 256^-1 (in Montgomery form).
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce_montgomery((uint64_t)s->c[i] *
                                (uint64_t)kInverseDegreeMontgomery);
  }
}
250
// Sets every coefficient of every entry of |out| to zero.
static void vectork_zero(vectork *out) {
  OPENSSL_memset(out, 0, sizeof(*out));
}
252
// Entry-wise |scalar_add| over the K entries: out[i] = lhs[i] + rhs[i].
static void vectork_add(vectork *out, const vectork *lhs, const vectork *rhs) {
  for (int row = 0; row < K; row++) {
    scalar *sum = &out->v[row];
    const scalar *a = &lhs->v[row];
    const scalar *b = &rhs->v[row];
    scalar_add(sum, a, b);
  }
}
258
// Entry-wise |scalar_sub| over the K entries: out[i] = lhs[i] - rhs[i].
static void vectork_sub(vectork *out, const vectork *lhs, const vectork *rhs) {
  for (int row = 0; row < K; row++) {
    scalar *diff = &out->v[row];
    const scalar *a = &lhs->v[row];
    const scalar *b = &rhs->v[row];
    scalar_sub(diff, a, b);
  }
}
264
// Multiplies each of the K entries of |lhs| by the single scalar |rhs| using
// |scalar_mult|.
static void vectork_mult_scalar(vectork *out, const vectork *lhs,
                                const scalar *rhs) {
  for (int row = 0; row < K; row++) {
    scalar *product = &out->v[row];
    const scalar *factor = &lhs->v[row];
    scalar_mult(product, factor, rhs);
  }
}
271
// Applies |scalar_ntt| in place to each of the K entries of |a|.
static void vectork_ntt(vectork *a) {
  for (int row = 0; row < K; row++) {
    scalar *entry = &a->v[row];
    scalar_ntt(entry);
  }
}
277
// Applies |scalar_inverse_ntt| in place to each of the K entries of |a|.
static void vectork_inverse_ntt(vectork *a) {
  for (int row = 0; row < K; row++) {
    scalar *entry = &a->v[row];
    scalar_inverse_ntt(entry);
  }
}
283
// Entry-wise |scalar_add| over the L entries: out[i] = lhs[i] + rhs[i].
static void vectorl_add(vectorl *out, const vectorl *lhs, const vectorl *rhs) {
  for (int col = 0; col < L; col++) {
    scalar *sum = &out->v[col];
    const scalar *a = &lhs->v[col];
    const scalar *b = &rhs->v[col];
    scalar_add(sum, a, b);
  }
}
289
// Multiplies each of the L entries of |lhs| by the single scalar |rhs| using
// |scalar_mult|.
static void vectorl_mult_scalar(vectorl *out, const vectorl *lhs,
                                const scalar *rhs) {
  for (int col = 0; col < L; col++) {
    scalar *product = &out->v[col];
    const scalar *factor = &lhs->v[col];
    scalar_mult(product, factor, rhs);
  }
}
296
// Applies |scalar_ntt| in place to each of the L entries of |a|.
static void vectorl_ntt(vectorl *a) {
  for (int col = 0; col < L; col++) {
    scalar *entry = &a->v[col];
    scalar_ntt(entry);
  }
}
302
// Applies |scalar_inverse_ntt| in place to each of the L entries of |a|.
static void vectorl_inverse_ntt(vectorl *a) {
  for (int col = 0; col < L; col++) {
    scalar *entry = &a->v[col];
    scalar_inverse_ntt(entry);
  }
}
308
// Matrix-vector product: out[i] = sum over j of m[i][j] * a[j], using
// |scalar_mult| for the products and |scalar_add| for the accumulation.
static void matrix_mult(vectork *out, const matrix *m, const vectorl *a) {
  vectork_zero(out);
  for (int row = 0; row < K; row++) {
    scalar *acc = &out->v[row];
    for (int col = 0; col < L; col++) {
      scalar term;
      scalar_mult(&term, &m->v[row][col], &a->v[col]);
      scalar_add(acc, acc, &term);
    }
  }
}
319
320 /* Rounding & hints */
321
322 // FIPS 204, Algorithm 29 (`Power2Round`).
power2_round(uint32_t * r1,uint32_t * r0,uint32_t r)323 static void power2_round(uint32_t *r1, uint32_t *r0, uint32_t r) {
324 *r1 = r >> kDroppedBits;
325 *r0 = r - (*r1 << kDroppedBits);
326
327 uint32_t r0_adjusted = reduce_once(kPrime + *r0 - (1 << kDroppedBits));
328 uint32_t r1_adjusted = *r1 + 1;
329
330 // Mask is set iff r0 > 2^(dropped_bits - 1).
331 crypto_word_t mask =
332 constant_time_lt_w((uint32_t)(1 << (kDroppedBits - 1)), *r0);
333 // r0 = mask ? r0_adjusted : r0
334 *r0 = constant_time_select_int(mask, r0_adjusted, *r0);
335 // r1 = mask ? r1_adjusted : r1
336 *r1 = constant_time_select_int(mask, r1_adjusted, *r1);
337 }
338
339 // Scale back previously rounded value.
scale_power2_round(uint32_t * out,uint32_t r1)340 static void scale_power2_round(uint32_t *out, uint32_t r1) {
341 // Pre-condition: 0 <= r1 <= 2^10 - 1
342 *out = r1 << kDroppedBits;
343 // Post-condition: 0 <= out <= 2^23 - 2^13 = kPrime - 1
344 assert(*out < kPrime);
345 }
346
// FIPS 204, Algorithm 31 (`HighBits`).
//
// Reference description (given 0 <= x < q):
//
// ```
// int32_t r0 = x mod+- (2 * kGamma2);
// if (x - r0 == q - 1) {
//   return 0;
// } else {
//   return (x - r0) / (2 * kGamma2);
// }
// ```
//
// The branch-free formula below is the one used by the reference
// implementation. With kGamma2 == 2^18 - 2^8 it evaluates
// ((ceil(x / 2^7) * (2^10 + 1) + 2^21) / 2^22) mod 2^4.
static uint32_t high_bits(uint32_t x) {
  const uint32_t ceil_div_128 = (x + 127) >> 7;
  const uint32_t scaled = ceil_div_128 * 1025 + (1 << 21);
  return (scaled >> 22) & 15;
}
369
// FIPS 204, Algorithm 30 (`Decompose`).
//
// Writes |high_bits(r)| to |*r1| and the matching signed remainder to |*r0|,
// i.e. r - r1 * 2 * kGamma2, lowered by kPrime when that value exceeds
// kHalfPrime.
static void decompose(uint32_t *r1, int32_t *r0, uint32_t r) {
  *r1 = high_bits(r);

  *r0 = r;
  *r0 -= *r1 * 2 * (int32_t)kGamma2;
  // Branch-free conditional subtraction: when *r0 > kHalfPrime the difference
  // is negative, so the shift produces an all-ones mask and kPrime is
  // subtracted. This relies on >> of a negative int32_t being an arithmetic
  // shift.
  *r0 -= (((int32_t)kHalfPrime - *r0) >> 31) & (int32_t)kPrime;
}
378
// FIPS 204, Algorithm 32 (`LowBits`).
//
// Returns the signed remainder produced by |decompose|; the high part is
// computed and discarded.
static int32_t low_bits(uint32_t x) {
  uint32_t high_unused;
  int32_t low;
  decompose(&high_unused, &low, x);
  return low;
}
386
387 // FIPS 204, Algorithm 33 (`MakeHint`).
make_hint(uint32_t ct0,uint32_t cs2,uint32_t w)388 static int32_t make_hint(uint32_t ct0, uint32_t cs2, uint32_t w) {
389 uint32_t r_plus_z = reduce_once(kPrime + w - cs2);
390 uint32_t r = reduce_once(r_plus_z + ct0);
391 return high_bits(r) != high_bits(r_plus_z);
392 }
393
// FIPS 204, Algorithm 34 (`UseHint`).
//
// Returns the high part of |r|, moved one step (modulo 16) in the direction of
// the sign of the low part when the hint |h| is nonzero. Branches on its
// inputs, hence "vartime".
static uint32_t use_hint_vartime(uint32_t h, uint32_t r) {
  uint32_t high;
  int32_t low;
  decompose(&high, &low, r);

  if (!h) {
    return high;
  }
  if (low > 0) {
    return (high + 1) & 15;
  }
  return (high - 1) & 15;
}
410
// Applies |power2_round| to every coefficient of |s|, writing the high parts
// to |s1| and the low parts to |s0|.
static void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {
  for (int idx = 0; idx < DEGREE; idx++) {
    uint32_t *high = &s1->c[idx];
    uint32_t *low = &s0->c[idx];
    power2_round(high, low, s->c[idx]);
  }
}
416
// Applies |scale_power2_round| to every coefficient of |in|.
static void scalar_scale_power2_round(scalar *out, const scalar *in) {
  for (int idx = 0; idx < DEGREE; idx++) {
    uint32_t *dst = &out->c[idx];
    scale_power2_round(dst, in->c[idx]);
  }
}
422
// Applies |high_bits| to every coefficient of |in|.
static void scalar_high_bits(scalar *out, const scalar *in) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint32_t coeff = in->c[idx];
    out->c[idx] = high_bits(coeff);
  }
}
428
// Applies |low_bits| to every coefficient of |in|. The signed results are
// stored in the unsigned coefficients of |out|, so negative values wrap.
static void scalar_low_bits(scalar *out, const scalar *in) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint32_t coeff = in->c[idx];
    out->c[idx] = low_bits(coeff);
  }
}
434
scalar_max(uint32_t * max,const scalar * s)435 static void scalar_max(uint32_t *max, const scalar *s) {
436 for (int i = 0; i < DEGREE; i++) {
437 uint32_t abs = abs_mod_prime(s->c[i]);
438 *max = maximum(*max, abs);
439 }
440 }
441
scalar_max_signed(uint32_t * max,const scalar * s)442 static void scalar_max_signed(uint32_t *max, const scalar *s) {
443 for (int i = 0; i < DEGREE; i++) {
444 uint32_t abs = abs_signed(s->c[i]);
445 *max = maximum(*max, abs);
446 }
447 }
448
// Applies |make_hint| coefficient by coefficient; each output is 0 or 1.
static void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
                             const scalar *w) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint32_t a = ct0->c[idx];
    const uint32_t b = cs2->c[idx];
    const uint32_t c = w->c[idx];
    out->c[idx] = make_hint(a, b, c);
  }
}
455
// Applies |use_hint_vartime| coefficient by coefficient.
static void scalar_use_hint_vartime(scalar *out, const scalar *h,
                                    const scalar *r) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint32_t hint = h->c[idx];
    const uint32_t value = r->c[idx];
    out->c[idx] = use_hint_vartime(hint, value);
  }
}
462
// Applies |scalar_power2_round| to each of the K entries of |t|.
static void vectork_power2_round(vectork *t1, vectork *t0, const vectork *t) {
  for (int row = 0; row < K; row++) {
    scalar *high = &t1->v[row];
    scalar *low = &t0->v[row];
    scalar_power2_round(high, low, &t->v[row]);
  }
}
468
// Applies |scalar_scale_power2_round| to each of the K entries of |in|.
static void vectork_scale_power2_round(vectork *out, const vectork *in) {
  for (int row = 0; row < K; row++) {
    scalar *dst = &out->v[row];
    const scalar *src = &in->v[row];
    scalar_scale_power2_round(dst, src);
  }
}
474
// Applies |scalar_high_bits| to each of the K entries of |in|.
static void vectork_high_bits(vectork *out, const vectork *in) {
  for (int row = 0; row < K; row++) {
    scalar *dst = &out->v[row];
    const scalar *src = &in->v[row];
    scalar_high_bits(dst, src);
  }
}
480
// Applies |scalar_low_bits| to each of the K entries of |in|.
static void vectork_low_bits(vectork *out, const vectork *in) {
  for (int row = 0; row < K; row++) {
    scalar *dst = &out->v[row];
    const scalar *src = &in->v[row];
    scalar_low_bits(dst, src);
  }
}
486
vectork_max(const vectork * a)487 static uint32_t vectork_max(const vectork *a) {
488 uint32_t max = 0;
489 for (int i = 0; i < K; i++) {
490 scalar_max(&max, &a->v[i]);
491 }
492 return max;
493 }
494
vectork_max_signed(const vectork * a)495 static uint32_t vectork_max_signed(const vectork *a) {
496 uint32_t max = 0;
497 for (int i = 0; i < K; i++) {
498 scalar_max_signed(&max, &a->v[i]);
499 }
500 return max;
501 }
502
503 // The input vector contains only zeroes and ones.
vectork_count_ones(const vectork * a)504 static size_t vectork_count_ones(const vectork *a) {
505 size_t count = 0;
506 for (int i = 0; i < K; i++) {
507 for (int j = 0; j < DEGREE; j++) {
508 count += a->v[i].c[j];
509 }
510 }
511 return count;
512 }
513
// Applies |scalar_make_hint| to each of the K entries.
static void vectork_make_hint(vectork *out, const vectork *ct0,
                              const vectork *cs2, const vectork *w) {
  for (int row = 0; row < K; row++) {
    scalar *hint = &out->v[row];
    scalar_make_hint(hint, &ct0->v[row], &cs2->v[row], &w->v[row]);
  }
}
520
// Applies |scalar_use_hint_vartime| to each of the K entries.
static void vectork_use_hint_vartime(vectork *out, const vectork *h,
                                     const vectork *r) {
  for (int row = 0; row < K; row++) {
    scalar *dst = &out->v[row];
    scalar_use_hint_vartime(dst, &h->v[row], &r->v[row]);
  }
}
527
vectorl_max(const vectorl * a)528 static uint32_t vectorl_max(const vectorl *a) {
529 uint32_t max = 0;
530 for (int i = 0; i < L; i++) {
531 scalar_max(&max, &a->v[i]);
532 }
533 return max;
534 }
535
536 /* Bit packing */
537
// kMasks[n - 1] has the low n bits set, for n from 1 to 8. The bit packing
// routines below use it to select an n-bit chunk.
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};
540
541 // FIPS 204, Algorithm 10 (`SimpleBitPack`).
scalar_encode(uint8_t * out,const scalar * s,int bits)542 static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
543 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
544
545 uint8_t out_byte = 0;
546 int out_byte_bits = 0;
547
548 for (int i = 0; i < DEGREE; i++) {
549 uint32_t element = s->c[i];
550 int element_bits_done = 0;
551
552 while (element_bits_done < bits) {
553 int chunk_bits = bits - element_bits_done;
554 int out_bits_remaining = 8 - out_byte_bits;
555 if (chunk_bits >= out_bits_remaining) {
556 chunk_bits = out_bits_remaining;
557 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
558 *out = out_byte;
559 out++;
560 out_byte_bits = 0;
561 out_byte = 0;
562 } else {
563 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
564 out_byte_bits += chunk_bits;
565 }
566
567 element_bits_done += chunk_bits;
568 element >>= chunk_bits;
569 }
570 }
571
572 if (out_byte_bits > 0) {
573 *out = out_byte;
574 }
575 }
576
577 // FIPS 204, Algorithm 11 (`BitPack`).
scalar_encode_signed(uint8_t * out,const scalar * s,int bits,uint32_t max)578 static void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
579 uint32_t max) {
580 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
581
582 uint8_t out_byte = 0;
583 int out_byte_bits = 0;
584
585 for (int i = 0; i < DEGREE; i++) {
586 uint32_t element = reduce_once(kPrime + max - s->c[i]);
587 declassify_assert(element <= 2 * max);
588 int element_bits_done = 0;
589
590 while (element_bits_done < bits) {
591 int chunk_bits = bits - element_bits_done;
592 int out_bits_remaining = 8 - out_byte_bits;
593 if (chunk_bits >= out_bits_remaining) {
594 chunk_bits = out_bits_remaining;
595 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
596 *out = out_byte;
597 out++;
598 out_byte_bits = 0;
599 out_byte = 0;
600 } else {
601 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
602 out_byte_bits += chunk_bits;
603 }
604
605 element_bits_done += chunk_bits;
606 element >>= chunk_bits;
607 }
608 }
609
610 if (out_byte_bits > 0) {
611 *out = out_byte;
612 }
613 }
614
615 // FIPS 204, Algorithm 12 (`SimpleBitUnpack`).
scalar_decode(scalar * out,const uint8_t * in,int bits)616 static void scalar_decode(scalar *out, const uint8_t *in, int bits) {
617 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
618
619 uint8_t in_byte = 0;
620 int in_byte_bits_left = 0;
621
622 for (int i = 0; i < DEGREE; i++) {
623 uint32_t element = 0;
624 int element_bits_done = 0;
625
626 while (element_bits_done < bits) {
627 if (in_byte_bits_left == 0) {
628 in_byte = *in;
629 in++;
630 in_byte_bits_left = 8;
631 }
632
633 int chunk_bits = bits - element_bits_done;
634 if (chunk_bits > in_byte_bits_left) {
635 chunk_bits = in_byte_bits_left;
636 }
637
638 element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
639 in_byte_bits_left -= chunk_bits;
640 in_byte >>= chunk_bits;
641
642 element_bits_done += chunk_bits;
643 }
644
645 out->c[i] = element;
646 }
647 }
648
649 // FIPS 204, Algorithm 13 (`BitUnpack`).
scalar_decode_signed(scalar * out,const uint8_t * in,int bits,uint32_t max)650 static int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
651 uint32_t max) {
652 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
653
654 uint8_t in_byte = 0;
655 int in_byte_bits_left = 0;
656
657 for (int i = 0; i < DEGREE; i++) {
658 uint32_t element = 0;
659 int element_bits_done = 0;
660
661 while (element_bits_done < bits) {
662 if (in_byte_bits_left == 0) {
663 in_byte = *in;
664 in++;
665 in_byte_bits_left = 8;
666 }
667
668 int chunk_bits = bits - element_bits_done;
669 if (chunk_bits > in_byte_bits_left) {
670 chunk_bits = in_byte_bits_left;
671 }
672
673 element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
674 in_byte_bits_left -= chunk_bits;
675 in_byte >>= chunk_bits;
676
677 element_bits_done += chunk_bits;
678 }
679
680 // This may be only out of range in cases of invalid input, in which case it
681 // is okay to leak the value. This function is also called with secret
682 // input during signing, in |scalar_sample_mask|. However, in that case
683 // (and in any case when |max| is a power of two), this case is impossible.
684 if (constant_time_declassify_int(element > 2 * max)) {
685 return 0;
686 }
687 out->c[i] = reduce_once(kPrime + max - element);
688 }
689
690 return 1;
691 }
692
693 /* Expansion functions */
694
// FIPS 204, Algorithm 24 (`RejNTTPoly`).
//
// Rejection samples a Keccak stream to get uniformly distributed elements. This
// is used for matrix expansion and only operates on public inputs.
//
// Absorbs the RHO_BYTES + 2 bytes of |derived_seed| into SHAKE-128, then
// squeezes whole 168-byte blocks until all DEGREE coefficients of |out| are
// filled.
static void scalar_from_keccak_vartime(
    scalar *out, const uint8_t derived_seed[RHO_BYTES + 2]) {
  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, RHO_BYTES + 2);
  assert(keccak_ctx.squeeze_offset == 0);
  assert(keccak_ctx.rate_bytes == 168);
  // Each candidate uses three bytes, so a block holds a whole number of them.
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      // FIPS 204, Algorithm 8 (`CoeffFromThreeBytes`).
      // Little-endian 23-bit value: the top bit of the third byte is dropped.
      uint32_t value = (uint32_t)block[i] | ((uint32_t)block[i + 1] << 8) |
                       (((uint32_t)block[i + 2] & 0x7f) << 16);
      // Keep only candidates that are already reduced; skip the rest.
      if (value < kPrime) {
        out->c[done++] = value;
      }
    }
  }
}
722
// FIPS 204, Algorithm 25 (`RejBoundedPoly`).
//
// Absorbs the SIGMA_BYTES + 2 bytes of |derived_seed| into SHAKE-256 and
// rejection samples the output one nibble at a time. Each accepted nibble t
// (t < 9) yields the coefficient (ETA - t) mod kPrime, until all DEGREE
// coefficients of |out| are filled.
static void scalar_uniform_eta_4(scalar *out,
                                 const uint8_t derived_seed[SIGMA_BYTES + 2]) {
  static_assert(ETA == 4, "This implementation is specialized for ETA == 4");

  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, SIGMA_BYTES + 2);
  assert(keccak_ctx.squeeze_offset == 0);
  assert(keccak_ctx.rate_bytes == 136);

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[136];
    BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; ++i) {
      // Low nibble first, then high nibble.
      uint32_t t0 = block[i] & 0x0F;
      uint32_t t1 = block[i] >> 4;
      // FIPS 204, Algorithm 9 (`CoefFromHalfByte`). Although both the input and
      // output here are secret, it is OK to leak when we rejected a byte.
      // Individual bytes of the SHAKE-256 stream are (indistinguishable from)
      // independent of each other and the original seed, so leaking information
      // about the rejected bytes does not reveal the input or output.
      if (constant_time_declassify_int(t0 < 9)) {
        out->c[done++] = reduce_once(kPrime + ETA - t0);
      }
      // Re-check |done|: the low nibble may have supplied the last coefficient.
      if (done < DEGREE && constant_time_declassify_int(t1 < 9)) {
        out->c[done++] = reduce_once(kPrime + ETA - t1);
      }
    }
  }
}
755
756 // FIPS 204, Algorithm 28 (`ExpandMask`).
scalar_sample_mask(scalar * out,const uint8_t derived_seed[RHO_PRIME_BYTES+2])757 static void scalar_sample_mask(
758 scalar *out, const uint8_t derived_seed[RHO_PRIME_BYTES + 2]) {
759 uint8_t buf[640];
760 BORINGSSL_keccak(buf, sizeof(buf), derived_seed, RHO_PRIME_BYTES + 2,
761 boringssl_shake256);
762
763 // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
764 scalar_decode_signed(out, buf, 20, 1 << 19);
765 }
766
// FIPS 204, Algorithm 23 (`SampleInBall`).
//
// Fills |out| with TAU coefficients equal to 1 or kPrime - 1 (i.e. -1) and
// zeroes everywhere else, driven by the SHAKE-256 stream of |seed|. |len| must
// be 32.
static void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed,
                                          int len) {
  assert(len == 32);

  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, seed, len);
  assert(keccak_ctx.squeeze_offset == 0);
  assert(keccak_ctx.rate_bytes == 136);

  uint8_t block[136];
  BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));

  // The first eight bytes supply the sign bits, consumed from the least
  // significant bit upward; position sampling starts at byte 8.
  uint64_t signs = CRYPTO_load_u64_le(block);
  int offset = 8;
  // SampleInBall implements a Fisher–Yates shuffle, which unavoidably leaks
  // where the zeros are by memory access pattern. Although this leak happens
  // before bad signatures are rejected, this is safe. See
  // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/8d8f01ac_70af3f21/
  CONSTTIME_DECLASSIFY(block + offset, sizeof(block) - offset);

  OPENSSL_memset(out, 0, sizeof(*out));
  for (size_t i = DEGREE - TAU; i < DEGREE; i++) {
    size_t byte;
    // Rejection sample a position in [0, i].
    for (;;) {
      // Squeeze another block once the current one is exhausted.
      if (offset == 136) {
        BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
        // See above.
        CONSTTIME_DECLASSIFY(block, sizeof(block));
        offset = 0;
      }

      byte = block[offset++];
      if (byte <= i) {
        break;
      }
    }

    // Move whatever was at the sampled position to slot i, then place +1 or
    // -1 (mod kPrime) at the sampled position according to the next sign bit.
    out->c[i] = out->c[byte];
    out->c[byte] = reduce_once(kPrime + 1 - 2 * (signs & 1));
    signs >>= 1;
  }
}
811
812 // FIPS 204, Algorithm 26 (`ExpandA`).
matrix_expand(matrix * out,const uint8_t rho[RHO_BYTES])813 static void matrix_expand(matrix *out, const uint8_t rho[RHO_BYTES]) {
814 static_assert(K <= 0x100, "K must fit in 8 bits");
815 static_assert(L <= 0x100, "L must fit in 8 bits");
816
817 uint8_t derived_seed[RHO_BYTES + 2];
818 OPENSSL_memcpy(derived_seed, rho, RHO_BYTES);
819 for (int i = 0; i < K; i++) {
820 for (int j = 0; j < L; j++) {
821 derived_seed[RHO_BYTES + 1] = i;
822 derived_seed[RHO_BYTES] = j;
823 scalar_from_keccak_vartime(&out->v[i][j], derived_seed);
824 }
825 }
826 }
827
828 // FIPS 204, Algorithm 27 (`ExpandS`).
vector_expand_short(vectorl * s1,vectork * s2,const uint8_t sigma[SIGMA_BYTES])829 static void vector_expand_short(vectorl *s1, vectork *s2,
830 const uint8_t sigma[SIGMA_BYTES]) {
831 static_assert(K <= 0x100, "K must fit in 8 bits");
832 static_assert(L <= 0x100, "L must fit in 8 bits");
833 static_assert(K + L <= 0x100, "K+L must fit in 8 bits");
834
835 uint8_t derived_seed[SIGMA_BYTES + 2];
836 OPENSSL_memcpy(derived_seed, sigma, SIGMA_BYTES);
837 derived_seed[SIGMA_BYTES] = 0;
838 derived_seed[SIGMA_BYTES + 1] = 0;
839 for (int i = 0; i < L; i++) {
840 scalar_uniform_eta_4(&s1->v[i], derived_seed);
841 ++derived_seed[SIGMA_BYTES];
842 }
843 for (int i = 0; i < K; i++) {
844 scalar_uniform_eta_4(&s2->v[i], derived_seed);
845 ++derived_seed[SIGMA_BYTES];
846 }
847 }
848
849 // FIPS 204, Algorithm 28 (`ExpandMask`).
vectorl_expand_mask(vectorl * out,const uint8_t seed[RHO_PRIME_BYTES],size_t kappa)850 static void vectorl_expand_mask(vectorl *out,
851 const uint8_t seed[RHO_PRIME_BYTES],
852 size_t kappa) {
853 assert(kappa + L <= 0x10000);
854
855 uint8_t derived_seed[RHO_PRIME_BYTES + 2];
856 OPENSSL_memcpy(derived_seed, seed, RHO_PRIME_BYTES);
857 for (int i = 0; i < L; i++) {
858 size_t index = kappa + i;
859 derived_seed[RHO_PRIME_BYTES] = index & 0xFF;
860 derived_seed[RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
861 scalar_sample_mask(&out->v[i], derived_seed);
862 }
863 }
864
865 /* Encoding */
866
867 // FIPS 204, Algorithm 10 (`SimpleBitPack`).
868 //
869 // Encodes an entire vector into 32*K*|bits| bytes. Note that since 256 (DEGREE)
870 // is divisible by 8, the individual vector entries will always fill a whole
871 // number of bytes, so we do not need to worry about bit packing here.
vectork_encode(uint8_t * out,const vectork * a,int bits)872 static void vectork_encode(uint8_t *out, const vectork *a, int bits) {
873 for (int i = 0; i < K; i++) {
874 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
875 }
876 }
877
878 // FIPS 204, Algorithm 12 (`SimpleBitUnpack`).
vectork_decode(vectork * out,const uint8_t * in,int bits)879 static void vectork_decode(vectork *out, const uint8_t *in, int bits) {
880 for (int i = 0; i < K; i++) {
881 scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits);
882 }
883 }
884
vectork_encode_signed(uint8_t * out,const vectork * a,int bits,uint32_t max)885 static void vectork_encode_signed(uint8_t *out, const vectork *a, int bits,
886 uint32_t max) {
887 for (int i = 0; i < K; i++) {
888 scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
889 }
890 }
891
// Decodes the K entries of |out| with |scalar_decode_signed|, each from
// |bits| * DEGREE / 8 consecutive bytes of |in|. Returns 1 on success, or 0 as
// soon as any entry fails to decode.
static int vectork_decode_signed(vectork *out, const uint8_t *in, int bits,
                                 uint32_t max) {
  const int entry_bytes = bits * DEGREE / 8;
  for (int row = 0; row < K; row++) {
    const uint8_t *src = in + row * entry_bytes;
    const int ok = scalar_decode_signed(&out->v[row], src, bits, max);
    if (!ok) {
      return 0;
    }
  }
  return 1;
}
902
903 // FIPS 204, Algorithm 11 (`BitPack`).
904 //
905 // Encodes an entire vector into 32*L*|bits| bytes. Note that since 256 (DEGREE)
906 // is divisible by 8, the individual vector entries will always fill a whole
907 // number of bytes, so we do not need to worry about bit packing here.
vectorl_encode_signed(uint8_t * out,const vectorl * a,int bits,uint32_t max)908 static void vectorl_encode_signed(uint8_t *out, const vectorl *a, int bits,
909 uint32_t max) {
910 for (int i = 0; i < L; i++) {
911 scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
912 }
913 }
914
// Decodes L entries from |in| into |out| by calling |scalar_decode_signed|
// with |bits| and |max| on each one. Stops at the first entry that fails to
// decode. Returns 1 if every entry decoded and 0 otherwise.
static int vectorl_decode_signed(vectorl *out, const uint8_t *in, int bits,
                                 uint32_t max) {
  const int stride = bits * DEGREE / 8;
  int ok = 1;
  for (int col = 0; ok && col < L; col++) {
    ok = scalar_decode_signed(&out->v[col], in + col * stride, bits, max) != 0;
  }
  return ok;
}
925
926 // FIPS 204, Algorithm 22 (`w1Encode`).
927 //
928 // The output must point to an array of 128*K bytes.
w1_encode(uint8_t * out,const vectork * w1)929 static void w1_encode(uint8_t *out, const vectork *w1) {
930 vectork_encode(out, w1, 4);
931 }
932
933 // FIPS 204, Algorithm 14 (`HintBitPack`).
hint_bit_pack(uint8_t * out,const vectork * h)934 static void hint_bit_pack(uint8_t *out, const vectork *h) {
935 OPENSSL_memset(out, 0, OMEGA + K);
936 int index = 0;
937 for (int i = 0; i < K; i++) {
938 for (int j = 0; j < DEGREE; j++) {
939 if (h->v[i].c[j]) {
940 out[index++] = j;
941 }
942 }
943 out[OMEGA + i] = index;
944 }
945 }
946
947 // FIPS 204, Algorithm 15 (`HintBitUnpack`).
hint_bit_unpack(vectork * h,const uint8_t * in)948 static int hint_bit_unpack(vectork *h, const uint8_t *in) {
949 vectork_zero(h);
950 int index = 0;
951 for (int i = 0; i < K; i++) {
952 int limit = in[OMEGA + i];
953 if (limit < index || limit > OMEGA) {
954 return 0;
955 }
956
957 int last = -1;
958 while (index < limit) {
959 int byte = in[index++];
960 if (last >= 0 && byte <= last) {
961 return 0;
962 }
963 last = byte;
964 h->v[i].c[byte] = 1;
965 }
966 }
967 for (; index < OMEGA; index++) {
968 if (in[index] != 0) {
969 return 0;
970 }
971 }
972 return 1;
973 }
974
// Internal representation of a public key. |public_key_from_external|
// statically asserts that its size and alignment equal those of
// |struct DILITHIUM_public_key|, so the layout must not change independently.
struct public_key {
  // Seed passed to |matrix_expand| to derive the matrix.
  uint8_t rho[RHO_BYTES];
  // First output of |vectork_power2_round|; serialized with 10 bits per
  // coefficient.
  vectork t1;
  // Pre-cached value(s).
  // SHAKE-256 digest of the encoded public key; absorbed ahead of the message
  // when computing mu during verification.
  uint8_t public_key_hash[TR_BYTES];
};
981
// Internal representation of a private key. |private_key_from_external|
// statically asserts that its size and alignment equal those of
// |struct DILITHIUM_private_key|, so the layout must not change independently.
// Fields appear in the order in which |dilithium_marshal_private_key| writes
// them.
struct private_key {
  // Seed passed to |matrix_expand| to derive the matrix.
  uint8_t rho[RHO_BYTES];
  // Secret seed absorbed, with the randomizer and mu, to derive rho' when
  // signing.
  uint8_t k[K_BYTES];
  // SHAKE-256 digest of the encoded public key; absorbed ahead of the message
  // when computing mu.
  uint8_t public_key_hash[TR_BYTES];
  // Secret vectors produced by |vector_expand_short|; serialized with 4 bits
  // per coefficient and bound ETA.
  vectorl s1;
  vectork s2;
  // Second output of |vectork_power2_round|; serialized with 13 bits per
  // coefficient and bound 1 << 12.
  vectork t0;
};
990
// Internal representation of a signature. Fields appear in the order in which
// |dilithium_marshal_signature| writes them.
struct signature {
  // Commitment hash squeezed from SHAKE-256 over mu and the encoded w1.
  uint8_t c_tilde[2 * LAMBDA_BYTES];
  // Response vector; serialized with 20 bits per coefficient and bound 1 << 19.
  vectorl z;
  // Hint vector; serialized by |hint_bit_pack| into OMEGA + K bytes.
  vectork h;
};
996
997 // FIPS 204, Algorithm 16 (`pkEncode`).
dilithium_marshal_public_key(CBB * out,const struct public_key * pub)998 static int dilithium_marshal_public_key(CBB *out,
999 const struct public_key *pub) {
1000 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
1001 return 0;
1002 }
1003
1004 uint8_t *vectork_output;
1005 if (!CBB_add_space(out, &vectork_output, 320 * K)) {
1006 return 0;
1007 }
1008 vectork_encode(vectork_output, &pub->t1, 10);
1009
1010 return 1;
1011 }
1012
1013 // FIPS 204, Algorithm 17 (`pkDecode`).
dilithium_parse_public_key(struct public_key * pub,CBS * in)1014 static int dilithium_parse_public_key(struct public_key *pub, CBS *in) {
1015 if (!CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
1016 return 0;
1017 }
1018
1019 CBS t1_bytes;
1020 if (!CBS_get_bytes(in, &t1_bytes, 320 * K)) {
1021 return 0;
1022 }
1023 vectork_decode(&pub->t1, CBS_data(&t1_bytes), 10);
1024
1025 return 1;
1026 }
1027
1028 // FIPS 204, Algorithm 18 (`skEncode`).
dilithium_marshal_private_key(CBB * out,const struct private_key * priv)1029 static int dilithium_marshal_private_key(CBB *out,
1030 const struct private_key *priv) {
1031 if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
1032 !CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
1033 !CBB_add_bytes(out, priv->public_key_hash,
1034 sizeof(priv->public_key_hash))) {
1035 return 0;
1036 }
1037
1038 uint8_t *vectorl_output;
1039 if (!CBB_add_space(out, &vectorl_output, 128 * L)) {
1040 return 0;
1041 }
1042 vectorl_encode_signed(vectorl_output, &priv->s1, 4, ETA);
1043
1044 uint8_t *vectork_output;
1045 if (!CBB_add_space(out, &vectork_output, 128 * K)) {
1046 return 0;
1047 }
1048 vectork_encode_signed(vectork_output, &priv->s2, 4, ETA);
1049
1050 if (!CBB_add_space(out, &vectork_output, 416 * K)) {
1051 return 0;
1052 }
1053 vectork_encode_signed(vectork_output, &priv->t0, 13, 1 << 12);
1054
1055 return 1;
1056 }
1057
1058 // FIPS 204, Algorithm 19 (`skDecode`).
dilithium_parse_private_key(struct private_key * priv,CBS * in)1059 static int dilithium_parse_private_key(struct private_key *priv, CBS *in) {
1060 CBS s1_bytes;
1061 CBS s2_bytes;
1062 CBS t0_bytes;
1063 if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
1064 !CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
1065 !CBS_copy_bytes(in, priv->public_key_hash,
1066 sizeof(priv->public_key_hash)) ||
1067 !CBS_get_bytes(in, &s1_bytes, 128 * L) ||
1068 !vectorl_decode_signed(&priv->s1, CBS_data(&s1_bytes), 4, ETA) ||
1069 !CBS_get_bytes(in, &s2_bytes, 128 * K) ||
1070 !vectork_decode_signed(&priv->s2, CBS_data(&s2_bytes), 4, ETA) ||
1071 !CBS_get_bytes(in, &t0_bytes, 416 * K) ||
1072 // Note: Decoding 13 bits into (-2^12, 2^12] cannot fail.
1073 !vectork_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
1074 return 0;
1075 }
1076
1077 return 1;
1078 }
1079
1080 // FIPS 204, Algorithm 20 (`sigEncode`).
dilithium_marshal_signature(CBB * out,const struct signature * sign)1081 static int dilithium_marshal_signature(CBB *out, const struct signature *sign) {
1082 if (!CBB_add_bytes(out, sign->c_tilde, sizeof(sign->c_tilde))) {
1083 return 0;
1084 }
1085
1086 uint8_t *vectorl_output;
1087 if (!CBB_add_space(out, &vectorl_output, 640 * L)) {
1088 return 0;
1089 }
1090 vectorl_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
1091
1092 uint8_t *hint_output;
1093 if (!CBB_add_space(out, &hint_output, OMEGA + K)) {
1094 return 0;
1095 }
1096 hint_bit_pack(hint_output, &sign->h);
1097
1098 return 1;
1099 }
1100
1101 // FIPS 204, Algorithm 21 (`sigDecode`).
dilithium_parse_signature(struct signature * sign,CBS * in)1102 static int dilithium_parse_signature(struct signature *sign, CBS *in) {
1103 CBS z_bytes;
1104 CBS hint_bytes;
1105 if (!CBS_copy_bytes(in, sign->c_tilde, sizeof(sign->c_tilde)) ||
1106 !CBS_get_bytes(in, &z_bytes, 640 * L) ||
1107 // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
1108 !vectorl_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
1109 !CBS_get_bytes(in, &hint_bytes, OMEGA + K) ||
1110 !hint_bit_unpack(&sign->h, CBS_data(&hint_bytes))) {
1111 return 0;
1112 };
1113
1114 return 1;
1115 }
1116
private_key_from_external(const struct DILITHIUM_private_key * external)1117 static struct private_key *private_key_from_external(
1118 const struct DILITHIUM_private_key *external) {
1119 static_assert(
1120 sizeof(struct DILITHIUM_private_key) == sizeof(struct private_key),
1121 "Kyber private key size incorrect");
1122 static_assert(
1123 alignof(struct DILITHIUM_private_key) == alignof(struct private_key),
1124 "Kyber private key align incorrect");
1125 return (struct private_key *)external;
1126 }
1127
public_key_from_external(const struct DILITHIUM_public_key * external)1128 static struct public_key *public_key_from_external(
1129 const struct DILITHIUM_public_key *external) {
1130 static_assert(
1131 sizeof(struct DILITHIUM_public_key) == sizeof(struct public_key),
1132 "Dilithium public key size incorrect");
1133 static_assert(
1134 alignof(struct DILITHIUM_public_key) == alignof(struct public_key),
1135 "Dilithium public key align incorrect");
1136 return (struct public_key *)external;
1137 }
1138
1139 /* API */
1140
1141 // Calls |DILITHIUM_generate_key_external_entropy| with random bytes from
1142 // |RAND_bytes|. Returns 1 on success and 0 on failure.
DILITHIUM_generate_key(uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],struct DILITHIUM_private_key * out_private_key)1143 int DILITHIUM_generate_key(
1144 uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],
1145 struct DILITHIUM_private_key *out_private_key) {
1146 uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY];
1147 RAND_bytes(entropy, sizeof(entropy));
1148 return DILITHIUM_generate_key_external_entropy(out_encoded_public_key,
1149 out_private_key, entropy);
1150 }
1151
1152 // FIPS 204, Algorithm 1 (`ML-DSA.KeyGen`). Returns 1 on success and 0 on
1153 // failure.
DILITHIUM_generate_key_external_entropy(uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],struct DILITHIUM_private_key * out_private_key,const uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY])1154 int DILITHIUM_generate_key_external_entropy(
1155 uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],
1156 struct DILITHIUM_private_key *out_private_key,
1157 const uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY]) {
1158 int ret = 0;
1159
1160 // Intermediate values, allocated on the heap to allow use when there is a
1161 // limited amount of stack.
1162 struct values_st {
1163 struct public_key pub;
1164 matrix a_ntt;
1165 vectorl s1_ntt;
1166 vectork t;
1167 };
1168 struct values_st *values = OPENSSL_malloc(sizeof(*values));
1169 if (values == NULL) {
1170 goto err;
1171 }
1172
1173 struct private_key *priv = private_key_from_external(out_private_key);
1174
1175 uint8_t expanded_seed[RHO_BYTES + SIGMA_BYTES + K_BYTES];
1176 BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), entropy,
1177 DILITHIUM_GENERATE_KEY_ENTROPY, boringssl_shake256);
1178 const uint8_t *const rho = expanded_seed;
1179 const uint8_t *const sigma = expanded_seed + RHO_BYTES;
1180 const uint8_t *const k = expanded_seed + RHO_BYTES + SIGMA_BYTES;
1181 // rho is public.
1182 CONSTTIME_DECLASSIFY(rho, RHO_BYTES);
1183 OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
1184 OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
1185 OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
1186
1187 matrix_expand(&values->a_ntt, rho);
1188 vector_expand_short(&priv->s1, &priv->s2, sigma);
1189
1190 OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1191 vectorl_ntt(&values->s1_ntt);
1192
1193 matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
1194 vectork_inverse_ntt(&values->t);
1195 vectork_add(&values->t, &values->t, &priv->s2);
1196
1197 vectork_power2_round(&values->pub.t1, &priv->t0, &values->t);
1198 // t1 is public.
1199 CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
1200
1201 CBB cbb;
1202 CBB_init_fixed(&cbb, out_encoded_public_key, DILITHIUM_PUBLIC_KEY_BYTES);
1203 if (!dilithium_marshal_public_key(&cbb, &values->pub)) {
1204 goto err;
1205 }
1206
1207 BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
1208 out_encoded_public_key, DILITHIUM_PUBLIC_KEY_BYTES,
1209 boringssl_shake256);
1210
1211 ret = 1;
1212 err:
1213 OPENSSL_free(values);
1214 return ret;
1215 }
1216
DILITHIUM_public_from_private(struct DILITHIUM_public_key * out_public_key,const struct DILITHIUM_private_key * private_key)1217 int DILITHIUM_public_from_private(
1218 struct DILITHIUM_public_key *out_public_key,
1219 const struct DILITHIUM_private_key *private_key) {
1220 int ret = 0;
1221
1222 // Intermediate values, allocated on the heap to allow use when there is a
1223 // limited amount of stack.
1224 struct values_st {
1225 matrix a_ntt;
1226 vectorl s1_ntt;
1227 vectork t;
1228 vectork t0;
1229 };
1230 struct values_st *values = OPENSSL_malloc(sizeof(*values));
1231 if (values == NULL) {
1232 goto err;
1233 }
1234
1235 const struct private_key *priv = private_key_from_external(private_key);
1236 struct public_key *pub = public_key_from_external(out_public_key);
1237
1238 OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
1239 OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
1240 sizeof(pub->public_key_hash));
1241
1242 matrix_expand(&values->a_ntt, priv->rho);
1243
1244 OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1245 vectorl_ntt(&values->s1_ntt);
1246
1247 matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
1248 vectork_inverse_ntt(&values->t);
1249 vectork_add(&values->t, &values->t, &priv->s2);
1250
1251 vectork_power2_round(&pub->t1, &values->t0, &values->t);
1252
1253 ret = 1;
1254 err:
1255 OPENSSL_free(values);
1256 return ret;
1257 }
1258
// FIPS 204, Algorithm 2 (`ML-DSA.Sign`). Returns 1 on success and 0 on failure.
//
// Signs |msg| (|msg_len| bytes) with |private_key| and writes the encoded
// signature to |out_encoded_signature|. |randomizer| is absorbed together with
// the private seed k and mu to derive the per-signature seed rho'.
//
// Failure is reported when the scratch allocation fails or when the signature
// cannot be serialized. Rejected candidate signatures are not failures: the
// loop retries with the next |kappa|.
static int dilithium_sign_with_randomizer(
    uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
    const struct DILITHIUM_private_key *private_key, const uint8_t *msg,
    size_t msg_len,
    const uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES]) {
  int ret = 0;

  const struct private_key *priv = private_key_from_external(private_key);

  // mu = SHAKE-256(public_key_hash || msg).
  uint8_t mu[MU_BYTES];
  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
                          sizeof(priv->public_key_hash));
  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);

  // rho' = SHAKE-256(k || randomizer || mu).
  uint8_t rho_prime[RHO_PRIME_BYTES];
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
  BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
                          DILITHIUM_SIGNATURE_RANDOMIZER_BYTES);
  BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
  BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, RHO_PRIME_BYTES);

  // Intermediate values, allocated on the heap to allow use when there is a
  // limited amount of stack.
  struct values_st {
    struct signature sign;
    vectorl s1_ntt;
    vectork s2_ntt;
    vectork t0_ntt;
    matrix a_ntt;
    vectorl y;
    vectorl y_ntt;
    vectork w;
    vectork w1;
    vectorl cs1;
    vectork cs2;
    vectork r0;
    vectork ct0;
  };
  struct values_st *values = OPENSSL_malloc(sizeof(*values));
  if (values == NULL) {
    goto err;
  }

  // The NTT forms of s1, s2 and t0 do not depend on |kappa|, so they are
  // computed once, on copies, before the rejection loop.
  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
  vectorl_ntt(&values->s1_ntt);

  OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
  vectork_ntt(&values->s2_ntt);

  OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
  vectork_ntt(&values->t0_ntt);

  matrix_expand(&values->a_ntt, priv->rho);

  // Rejection-sampling loop: each iteration derives a fresh mask from rho' and
  // |kappa|; |kappa| advances by L on every attempt, including rejected ones.
  for (size_t kappa = 0;; kappa += L) {
    // TODO(bbe): y only lives long enough to compute y_ntt.
    // consider using another vectorl to save memory.
    // NOTE(review): y is also read below when computing |sign.z|.
    vectorl_expand_mask(&values->y, rho_prime, kappa);

    OPENSSL_memcpy(&values->y_ntt, &values->y, sizeof(values->y_ntt));
    vectorl_ntt(&values->y_ntt);

    // TODO(bbe): consider whether w can share storage with another vectork to
    // save memory.
    // NOTE(review): w is still read below by |vectork_sub| and
    // |vectork_make_hint|, so it must stay live until the hint is computed.
    matrix_mult(&values->w, &values->a_ntt, &values->y_ntt);
    vectork_inverse_ntt(&values->w);

    vectork_high_bits(&values->w1, &values->w);
    uint8_t w1_encoded[128 * K];
    w1_encode(w1_encoded, &values->w1);

    // c_tilde = SHAKE-256(mu || w1_encoded).
    BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
    BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
    BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
    BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
                             2 * LAMBDA_BYTES);

    // Derive the challenge polynomial from c_tilde and move it to NTT form.
    scalar c_ntt;
    scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde, 32);
    scalar_ntt(&c_ntt);

    vectorl_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
    vectorl_inverse_ntt(&values->cs1);
    vectork_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
    vectork_inverse_ntt(&values->cs2);

    // z = y + c*s1.
    vectorl_add(&values->sign.z, &values->y, &values->cs1);

    // r0 = LowBits(w - c*s2).
    vectork_sub(&values->r0, &values->w, &values->cs2);
    vectork_low_bits(&values->r0, &values->r0);

    // Leaking the fact that a signature was rejected is fine as the next
    // attempt at a signature will be (indistinguishable from) independent of
    // this one. Note, however, that we additionally leak which of the two
    // branches rejected the signature. Section 5.5 of
    // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
    // describes this leak as OK. Note we leak less than what is described by
    // the paper; we do not reveal which coefficient violated the bound, and we
    // hide which of the |z_max| or |r0_max| bound failed. See also
    // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
    uint32_t z_max = vectorl_max(&values->sign.z);
    uint32_t r0_max = vectork_max_signed(&values->r0);
    // The two comparisons are OR-ed before being declassified, so only the
    // combined accept/reject bit becomes public.
    if (constant_time_declassify_w(
            constant_time_ge_w(z_max, kGamma1 - BETA) |
            constant_time_ge_w(r0_max, kGamma2 - BETA))) {
      continue;
    }

    vectork_mult_scalar(&values->ct0, &values->t0_ntt, &c_ntt);
    vectork_inverse_ntt(&values->ct0);
    vectork_make_hint(&values->sign.h, &values->ct0, &values->cs2, &values->w);

    // See above.
    uint32_t ct0_max = vectork_max(&values->ct0);
    size_t h_ones = vectork_count_ones(&values->sign.h);
    // Reject if c*t0 is too large or the hint has more than OMEGA ones.
    if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
                                   constant_time_lt_w(OMEGA, h_ones))) {
      continue;
    }

    // Although computed with the private key, the signature is public.
    CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
    CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
    CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));

    CBB cbb;
    CBB_init_fixed(&cbb, out_encoded_signature, DILITHIUM_SIGNATURE_BYTES);
    if (!dilithium_marshal_signature(&cbb, &values->sign)) {
      goto err;
    }

    // The encoding must fill the output buffer exactly.
    BSSL_CHECK(CBB_len(&cbb) == DILITHIUM_SIGNATURE_BYTES);
    ret = 1;
    break;
  }

err:
  OPENSSL_free(values);
  return ret;
}
1403
1404 // Dilithium signature in deterministic mode. Returns 1 on success and 0 on
1405 // failure.
DILITHIUM_sign_deterministic(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len)1406 int DILITHIUM_sign_deterministic(
1407 uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1408 const struct DILITHIUM_private_key *private_key, const uint8_t *msg,
1409 size_t msg_len) {
1410 uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES];
1411 OPENSSL_memset(randomizer, 0, sizeof(randomizer));
1412 return dilithium_sign_with_randomizer(out_encoded_signature, private_key, msg,
1413 msg_len, randomizer);
1414 }
1415
1416 // Dilithium signature in randomized mode, filling the random bytes with
1417 // |RAND_bytes|. Returns 1 on success and 0 on failure.
DILITHIUM_sign(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len)1418 int DILITHIUM_sign(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1419 const struct DILITHIUM_private_key *private_key,
1420 const uint8_t *msg, size_t msg_len) {
1421 uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES];
1422 RAND_bytes(randomizer, sizeof(randomizer));
1423 return dilithium_sign_with_randomizer(out_encoded_signature, private_key, msg,
1424 msg_len, randomizer);
1425 }
1426
// FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
//
// Checks |encoded_signature| over |msg| (|msg_len| bytes) against
// |public_key|. Returns 1 if the signature is valid. Returns 0 if it is
// invalid, if it cannot be parsed, or if the scratch allocation fails.
int DILITHIUM_verify(const struct DILITHIUM_public_key *public_key,
                     const uint8_t encoded_signature[DILITHIUM_SIGNATURE_BYTES],
                     const uint8_t *msg, size_t msg_len) {
  int ret = 0;

  // Intermediate values, allocated on the heap to allow use when there is a
  // limited amount of stack.
  struct values_st {
    struct signature sign;
    matrix a_ntt;
    vectorl z_ntt;
    vectork az_ntt;
    vectork t1_ntt;
    vectork ct1_ntt;
    vectork w_approx;
    vectork w1;
  };
  struct values_st *values = OPENSSL_malloc(sizeof(*values));
  if (values == NULL) {
    goto err;
  }

  const struct public_key *pub = public_key_from_external(public_key);

  // Reject signatures whose encoding does not parse (for example a malformed
  // hint) before doing any expensive work.
  CBS cbs;
  CBS_init(&cbs, encoded_signature, DILITHIUM_SIGNATURE_BYTES);
  if (!dilithium_parse_signature(&values->sign, &cbs)) {
    goto err;
  }

  matrix_expand(&values->a_ntt, pub->rho);

  // mu = SHAKE-256(public_key_hash || msg).
  uint8_t mu[MU_BYTES];
  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
                          sizeof(pub->public_key_hash));
  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);

  // Rebuild the challenge polynomial from the signature's c_tilde.
  scalar c_ntt;
  scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde, 32);
  scalar_ntt(&c_ntt);

  // z is transformed in a copy because the untransformed value is still needed
  // for the norm check at the end.
  OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
  vectorl_ntt(&values->z_ntt);

  matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);

  vectork_scale_power2_round(&values->t1_ntt, &pub->t1);
  vectork_ntt(&values->t1_ntt);

  vectork_mult_scalar(&values->ct1_ntt, &values->t1_ntt, &c_ntt);

  // w_approx = InverseNTT(A*z - c*scaled_t1); the subtraction is done on the
  // NTT-domain values.
  vectork_sub(&values->w_approx, &values->az_ntt, &values->ct1_ntt);
  vectork_inverse_ntt(&values->w_approx);

  vectork_use_hint_vartime(&values->w1, &values->sign.h, &values->w_approx);
  uint8_t w1_encoded[128 * K];
  w1_encode(w1_encoded, &values->w1);

  // Recompute c_tilde = SHAKE-256(mu || w1_encoded) for comparison with the
  // value carried in the signature.
  uint8_t c_tilde[2 * LAMBDA_BYTES];
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
  BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
  BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * LAMBDA_BYTES);

  // Accept only if z is within bounds, the hint has at most OMEGA ones, and
  // the recomputed c_tilde matches the one in the signature.
  uint32_t z_max = vectorl_max(&values->sign.z);
  size_t h_ones = vectork_count_ones(&values->sign.h);
  if (z_max < kGamma1 - BETA && h_ones <= OMEGA &&
      OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * LAMBDA_BYTES) == 0) {
    ret = 1;
  }

err:
  OPENSSL_free(values);
  return ret;
}
1506
1507 /* Serialization of keys. */
1508
// Serializes |public_key| to |out| with |dilithium_marshal_public_key|.
// Returns 1 on success and 0 on failure.
int DILITHIUM_marshal_public_key(
    CBB *out, const struct DILITHIUM_public_key *public_key) {
  const struct public_key *pub = public_key_from_external(public_key);
  return dilithium_marshal_public_key(out, pub);
}
1514
// Parses a public key from |in| into |public_key| and fills in the cached
// SHAKE-256 hash of its encoding. |in| must contain exactly one encoded key;
// trailing bytes are rejected. Returns 1 on success and 0 on failure.
int DILITHIUM_parse_public_key(struct DILITHIUM_public_key *public_key,
                               CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);

  // Remember where the encoding starts and how long it is before parsing
  // consumes |in|.
  const uint8_t *const encoded = CBS_data(in);
  const size_t encoded_len = CBS_len(in);

  if (!dilithium_parse_public_key(pub, in)) {
    return 0;
  }
  if (CBS_len(in) != 0) {
    return 0;
  }

  // Compute pre-cached values.
  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash), encoded,
                   encoded_len, boringssl_shake256);
  return 1;
}
1528
// Serializes |private_key| to |out| with |dilithium_marshal_private_key|.
// Returns 1 on success and 0 on failure.
int DILITHIUM_marshal_private_key(
    CBB *out, const struct DILITHIUM_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  return dilithium_marshal_private_key(out, priv);
}
1534
// Parses a private key from |in| into |private_key|. |in| must contain
// exactly one encoded key; trailing bytes are rejected. Returns 1 on success
// and 0 on failure.
int DILITHIUM_parse_private_key(struct DILITHIUM_private_key *private_key,
                                CBS *in) {
  struct private_key *priv = private_key_from_external(private_key);
  if (!dilithium_parse_private_key(priv, in)) {
    return 0;
  }
  return CBS_len(in) == 0;
}
1540