xref: /aosp_15_r20/external/boringssl/src/crypto/dilithium/dilithium.c (revision 8fb009dc861624b67b6cdb62ea21f0f22d0c584b)
1 /* Copyright (c) 2023, Google LLC
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #define OPENSSL_UNSTABLE_EXPERIMENTAL_DILITHIUM
16 #include <openssl/experimental/dilithium.h>
17 
18 #include <assert.h>
19 #include <stdlib.h>
20 
21 #include <openssl/bytestring.h>
22 #include <openssl/rand.h>
23 
24 #include "../internal.h"
25 #include "../keccak/internal.h"
26 #include "./internal.h"
27 
// ML-DSA parameters. The values below (k = 6, l = 5, eta = 4, tau = 49,
// beta = 196, omega = 55, gamma1 = 2^19, gamma2 = (q - 1) / 32,
// lambda = 192) are the ML-DSA-65 (Dilithium3) parameter set of FIPS 204.

// Number of coefficients in a polynomial (|scalar|).
#define DEGREE 256
// Number of rows of the matrix A, i.e. the length of a |vectork|.
#define K 6
// Number of columns of the matrix A, i.e. the length of a |vectorl|.
#define L 5
// Secret coefficients are sampled from [-ETA, ETA] (|scalar_uniform_eta_4|).
#define ETA 4
// Number of nonzero (+1/-1) coefficients in a challenge polynomial.
#define TAU 49
// Equal to TAU * ETA.
#define BETA 196
// FIPS 204's omega; presumably the maximum number of ones allowed in a hint
// vector (not used in this part of the file).
#define OMEGA 55

// Byte lengths of the FIPS 204 values named by each macro.
#define RHO_BYTES 32
#define SIGMA_BYTES 64
#define K_BYTES 32
#define TR_BYTES 64
#define MU_BYTES 64
#define RHO_PRIME_BYTES 64
#define LAMBDA_BITS 192
#define LAMBDA_BYTES (LAMBDA_BITS / 8)

// q = 2^23 - 2^13 + 1
static const uint32_t kPrime = 8380417;
// Inverse of -kPrime modulo 2^32 (used by |reduce_montgomery|).
static const uint32_t kPrimeNegInverse = 4236238847;
// FIPS 204's d: the number of low bits split off by |power2_round|.
static const int kDroppedBits = 13;
// (q - 1) / 2
static const uint32_t kHalfPrime = (8380417 - 1) / 2;
// FIPS 204's gamma1 = 2^19.
static const uint32_t kGamma1 = 1 << 19;
// FIPS 204's gamma2 = (q - 1) / 32.
static const uint32_t kGamma2 = (8380417 - 1) / 32;
// 2^64 / 256 mod kPrime (equivalently 2^56 mod kPrime). This is the
// Montgomery form of 2^32 / 256, NOT of 256^-1 (that would be 16382).
// Multiplying by it with |reduce_montgomery| scales by 2^32 / 256. That both
// divides by DEGREE and cancels the 2^-32 factor left by a preceding Montgomery
// multiplication (|scalar_mult|).
static const uint32_t kInverseDegreeMontgomery = 41978;
55 
56 typedef struct scalar {
57   uint32_t c[DEGREE];
58 } scalar;
59 
60 typedef struct vectork {
61   scalar v[K];
62 } vectork;
63 
64 typedef struct vectorl {
65   scalar v[L];
66 } vectorl;
67 
68 typedef struct matrix {
69   scalar v[K][L];
70 } matrix;
71 
72 /* Arithmetic */
73 
74 // This bit of Python will be referenced in some of the following comments:
75 //
76 // q = 8380417
77 // # Inverse of -q modulo 2^32
78 // q_neg_inverse = 4236238847
79 // # 2^64 modulo q
80 // montgomery_square = 2365951
81 //
82 // def bitreverse(i):
83 //     ret = 0
84 //     for n in range(8):
85 //         bit = i & 1
86 //         ret <<= 1
87 //         ret |= bit
88 //         i >>= 1
89 //     return ret
90 //
91 // def montgomery_reduce(x):
92 //     a = (x * q_neg_inverse) % 2**32
93 //     b = x + a * q
94 //     assert b & 0xFFFF_FFFF == 0
95 //     c = b >> 32
96 //     assert c < q
97 //     return c
98 //
99 // def montgomery_transform(x):
100 //     return montgomery_reduce(x * montgomery_square)
101 
// Twiddle factors for |scalar_ntt| and |scalar_inverse_ntt|, generated as
//
//   kNTTRootsMontgomery = [
//     montgomery_transform(pow(1753, bitreverse(i), q)) for i in range(256)
//   ]
//
// That is, entry i is 1753^bitreverse8(i) * 2^32 mod q: the root is stored
// premultiplied by the Montgomery factor, so a single |reduce_montgomery| of
// (entry * c) yields root * c.
//
// The NTT layer with g butterfly groups reads entries [g, 2g). Entry 0
// (2^32 mod q, the Montgomery form of 1) is not read by either transform.
static const uint32_t kNTTRootsMontgomery[256] = {
    // [0, 8): entry 0, then the layers with 1, 2 and 4 groups.
    4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468,
    // [8, 16): layer with 8 groups.
    1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
    // [16, 32): layer with 16 groups.
    2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868,
    6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005,
    // [32, 64): layer with 32 groups.
    2706023, 95776, 3077325, 3530437, 6718724, 4788269, 5842901, 3915439,
    4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118,
    6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596,
    811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892, 5582638,
    // [64, 128): layer with 64 groups.
    4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196,
    7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922,
    3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370,
    7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987,
    5037034, 264944, 508951, 3097992, 44288, 7280319, 904516, 3958618,
    4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561,
    189548, 4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330,
    1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961,
    // [128, 256): layer with 128 groups.
    2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955,
    266997, 2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039,
    900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917,
    7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579,
    342297, 286988, 5942594, 4108315, 3437287, 5038140, 1735879, 203044,
    2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974,
    4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447,
    7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775,
    7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031,
    7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136, 4603424,
    6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032,
    5196991, 162844, 1616392, 3014001, 810149, 1652634, 4686184, 6581310,
    5341501, 3523897, 3866901, 269760, 2213111, 7404533, 1717735, 472078,
    7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524,
    5441381, 6144432, 7959518, 6094090, 183443, 7403526, 1612842, 4834730,
    7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782};
138 
139 // Reduces x mod kPrime in constant time, where 0 <= x < 2*kPrime.
reduce_once(uint32_t x)140 static uint32_t reduce_once(uint32_t x) {
141   declassify_assert(x < 2 * kPrime);
142   // return x < kPrime ? x : x - kPrime;
143   return constant_time_select_int(constant_time_lt_w(x, kPrime), x, x - kPrime);
144 }
145 
// Returns the absolute value of |x| in constant time, where |x| is interpreted
// as a two's-complement int32_t. For x == 0x80000000 (INT32_MIN) the result
// is 0x80000000.
//
// The select is done on crypto_word_t. Both |x| and its negation can be
// >= 2^31, and converting such values to int (as |constant_time_select_int|
// would) is implementation-defined.
static uint32_t abs_signed(uint32_t x) {
  // return is_positive(x) ? x : -x;
  // Note: MSVC doesn't like applying the unary minus operator to unsigned types
  // (warning C4146), so we write the negation as a bitwise not plus one
  // (assuming two's complement representation).
  const uint32_t negated = ~x + 1;
  return (uint32_t)constant_time_select_w(constant_time_lt_w(x, 0x80000000), x,
                                          negated);
}
154 
155 // Returns the absolute value modulo kPrime.
abs_mod_prime(uint32_t x)156 static uint32_t abs_mod_prime(uint32_t x) {
157   declassify_assert(x < kPrime);
158   // return x > kHalfPrime ? kPrime - x : x;
159   return constant_time_select_int(constant_time_lt_w(kHalfPrime, x), kPrime - x,
160                                   x);
161 }
162 
// Returns the larger of |x| and |y| in constant time.
//
// The select is done on crypto_word_t. Inputs can be as large as 0x80000000
// (e.g. from |abs_signed|), and converting them to int (as
// |constant_time_select_int| would) is implementation-defined.
static uint32_t maximum(uint32_t x, uint32_t y) {
  // return x < y ? y : x;
  return (uint32_t)constant_time_select_w(constant_time_lt_w(x, y), y, x);
}
168 
// Coefficient-wise addition modulo kPrime: out = lhs + rhs. Inputs are
// expected in [0, kPrime) (|reduce_once| asserts the sum is below
// 2 * kPrime). Each coefficient is read before it is written, so |out| may
// alias an input.
static void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (size_t idx = 0; idx < DEGREE; ++idx) {
    const uint32_t sum = lhs->c[idx] + rhs->c[idx];
    out->c[idx] = reduce_once(sum);
  }
}
174 
// Coefficient-wise subtraction modulo kPrime: out = lhs - rhs. Inputs are
// expected in [0, kPrime); kPrime is added first so the unsigned difference
// does not wrap. Each coefficient is read before it is written, so |out| may
// alias an input.
static void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (size_t idx = 0; idx < DEGREE; ++idx) {
    const uint32_t diff = kPrime + lhs->c[idx] - rhs->c[idx];
    out->c[idx] = reduce_once(diff);
  }
}
180 
reduce_montgomery(uint64_t x)181 static uint32_t reduce_montgomery(uint64_t x) {
182   uint64_t a = (uint32_t)x * kPrimeNegInverse;
183   uint64_t b = x + a * kPrime;
184   declassify_assert((b & 0xffffffff) == 0);
185   uint32_t c = b >> 32;
186   return reduce_once(c);
187 }
188 
189 // Multiply two scalars in the number theoretically transformed state.
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)190 static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
191   for (int i = 0; i < DEGREE; i++) {
192     out->c[i] = reduce_montgomery((uint64_t)lhs->c[i] * (uint64_t)rhs->c[i]);
193   }
194 }
195 
196 // In place number theoretic transform of a given scalar.
197 //
198 // FIPS 204, Algorithm 35 (`NTT`).
scalar_ntt(scalar * s)199 static void scalar_ntt(scalar *s) {
200   // Step: 1, 2, 4, 8, ..., 128
201   // Offset: 128, 64, 32, 16, ..., 1
202   int offset = DEGREE;
203   for (int step = 1; step < DEGREE; step <<= 1) {
204     offset >>= 1;
205     int k = 0;
206     for (int i = 0; i < step; i++) {
207       assert(k == 2 * offset * i);
208       const uint32_t step_root = kNTTRootsMontgomery[step + i];
209       for (int j = k; j < k + offset; j++) {
210         uint32_t even = s->c[j];
211         uint32_t odd =
212             reduce_montgomery((uint64_t)step_root * (uint64_t)s->c[j + offset]);
213         s->c[j] = reduce_once(odd + even);
214         s->c[j + offset] = reduce_once(kPrime + even - odd);
215       }
216       k += 2 * offset;
217     }
218   }
219 }
220 
221 // In place inverse number theoretic transform of a given scalar.
222 //
223 // FIPS 204, Algorithm 36 (`NTT^-1`).
scalar_inverse_ntt(scalar * s)224 static void scalar_inverse_ntt(scalar *s) {
225   // Step: 128, 64, 32, 16, ..., 1
226   // Offset: 1, 2, 4, 8, ..., 128
227   int step = DEGREE;
228   for (int offset = 1; offset < DEGREE; offset <<= 1) {
229     step >>= 1;
230     int k = 0;
231     for (int i = 0; i < step; i++) {
232       assert(k == 2 * offset * i);
233       const uint32_t step_root =
234           kPrime - kNTTRootsMontgomery[step + (step - 1 - i)];
235       for (int j = k; j < k + offset; j++) {
236         uint32_t even = s->c[j];
237         uint32_t odd = s->c[j + offset];
238         s->c[j] = reduce_once(odd + even);
239         s->c[j + offset] = reduce_montgomery((uint64_t)step_root *
240                                              (uint64_t)(kPrime + even - odd));
241       }
242       k += 2 * offset;
243     }
244   }
245   for (int i = 0; i < DEGREE; i++) {
246     s->c[i] = reduce_montgomery((uint64_t)s->c[i] *
247                                 (uint64_t)kInverseDegreeMontgomery);
248   }
249 }
250 
// Sets every coefficient of every polynomial in |out| to zero.
static void vectork_zero(vectork *out) {
  OPENSSL_memset(out, 0, sizeof(vectork));
}
252 
// out = lhs + rhs (mod kPrime), entry by entry.
static void vectork_add(vectork *out, const vectork *lhs, const vectork *rhs) {
  for (int row = 0; row < K; ++row) {
    scalar_add(out->v + row, lhs->v + row, rhs->v + row);
  }
}
258 
// out = lhs - rhs (mod kPrime), entry by entry.
static void vectork_sub(vectork *out, const vectork *lhs, const vectork *rhs) {
  for (int row = 0; row < K; ++row) {
    scalar_sub(out->v + row, lhs->v + row, rhs->v + row);
  }
}
264 
// Multiplies each entry of |lhs| pointwise by |rhs| (NTT domain, via
// |scalar_mult|), writing the results to |out|.
static void vectork_mult_scalar(vectork *out, const vectork *lhs,
                                const scalar *rhs) {
  for (int row = 0; row < K; ++row) {
    scalar_mult(out->v + row, lhs->v + row, rhs);
  }
}
271 
// Applies |scalar_ntt| to each of the K entries of |a|, in place.
static void vectork_ntt(vectork *a) {
  for (scalar *it = a->v; it != a->v + K; ++it) {
    scalar_ntt(it);
  }
}
277 
// Applies |scalar_inverse_ntt| to each of the K entries of |a|, in place.
static void vectork_inverse_ntt(vectork *a) {
  for (scalar *it = a->v; it != a->v + K; ++it) {
    scalar_inverse_ntt(it);
  }
}
283 
// out = lhs + rhs (mod kPrime), entry by entry.
static void vectorl_add(vectorl *out, const vectorl *lhs, const vectorl *rhs) {
  for (int col = 0; col < L; ++col) {
    scalar_add(out->v + col, lhs->v + col, rhs->v + col);
  }
}
289 
// Multiplies each entry of |lhs| pointwise by |rhs| (NTT domain, via
// |scalar_mult|), writing the results to |out|.
static void vectorl_mult_scalar(vectorl *out, const vectorl *lhs,
                                const scalar *rhs) {
  for (int col = 0; col < L; ++col) {
    scalar_mult(out->v + col, lhs->v + col, rhs);
  }
}
296 
// Applies |scalar_ntt| to each of the L entries of |a|, in place.
static void vectorl_ntt(vectorl *a) {
  for (scalar *it = a->v; it != a->v + L; ++it) {
    scalar_ntt(it);
  }
}
302 
// Applies |scalar_inverse_ntt| to each of the L entries of |a|, in place.
static void vectorl_inverse_ntt(vectorl *a) {
  for (scalar *it = a->v; it != a->v + L; ++it) {
    scalar_inverse_ntt(it);
  }
}
308 
// Computes out = m * a for NTT-domain operands: out_i = sum_j m_ij * a_j.
// Products use |scalar_mult|, so every output entry carries the same 2^-32
// Montgomery factor as a single pointwise product.
static void matrix_mult(vectork *out, const matrix *m, const vectorl *a) {
  vectork_zero(out);
  for (int row = 0; row < K; row++) {
    scalar *acc = &out->v[row];
    for (int col = 0; col < L; col++) {
      scalar term;
      scalar_mult(&term, &m->v[row][col], &a->v[col]);
      scalar_add(acc, acc, &term);
    }
  }
}
319 
320 /* Rounding & hints */
321 
322 // FIPS 204, Algorithm 29 (`Power2Round`).
power2_round(uint32_t * r1,uint32_t * r0,uint32_t r)323 static void power2_round(uint32_t *r1, uint32_t *r0, uint32_t r) {
324   *r1 = r >> kDroppedBits;
325   *r0 = r - (*r1 << kDroppedBits);
326 
327   uint32_t r0_adjusted = reduce_once(kPrime + *r0 - (1 << kDroppedBits));
328   uint32_t r1_adjusted = *r1 + 1;
329 
330   // Mask is set iff r0 > 2^(dropped_bits - 1).
331   crypto_word_t mask =
332       constant_time_lt_w((uint32_t)(1 << (kDroppedBits - 1)), *r0);
333   // r0 = mask ? r0_adjusted : r0
334   *r0 = constant_time_select_int(mask, r0_adjusted, *r0);
335   // r1 = mask ? r1_adjusted : r1
336   *r1 = constant_time_select_int(mask, r1_adjusted, *r1);
337 }
338 
339 // Scale back previously rounded value.
scale_power2_round(uint32_t * out,uint32_t r1)340 static void scale_power2_round(uint32_t *out, uint32_t r1) {
341   // Pre-condition: 0 <= r1 <= 2^10 - 1
342   *out = r1 << kDroppedBits;
343   // Post-condition: 0 <= out <= 2^23 - 2^13 = kPrime - 1
344   assert(*out < kPrime);
345 }
346 
// FIPS 204, Algorithm 31 (`HighBits`).
//
// For 0 <= x < q, returns r1 such that x = r1 * 2 * gamma2 + r0 with
// r0 in (-gamma2, gamma2]. The top bucket (where x - r0 == q - 1) maps to 0,
// so the result is in [0, 15]. The reference description is:
//
// ```
// int32_t r0 = x mod+- (2 * kGamma2);
// if (x - r0 == q - 1) {
//   return 0;
// } else {
//   return (x - r0) / (2 * kGamma2);
// }
// ```
//
// With gamma2 = 2^18 - 2^8, the reference implementation computes this
// without division as
//   ((ceil(x / 2^7) * (2^10 + 1) + 2^21) / 2^22) mod 2^4.
static uint32_t high_bits(uint32_t x) {
  // ceil(x / 2^7)
  const uint32_t scaled = (x + 127) >> 7;
  const uint32_t rounded = (scaled * 1025 + (1 << 21)) >> 22;
  // mod 2^4
  return rounded & 15;
}
369 
370 // FIPS 204, Algorithm 30 (`Decompose`).
decompose(uint32_t * r1,int32_t * r0,uint32_t r)371 static void decompose(uint32_t *r1, int32_t *r0, uint32_t r) {
372   *r1 = high_bits(r);
373 
374   *r0 = r;
375   *r0 -= *r1 * 2 * (int32_t)kGamma2;
376   *r0 -= (((int32_t)kHalfPrime - *r0) >> 31) & (int32_t)kPrime;
377 }
378 
// FIPS 204, Algorithm 32 (`LowBits`).
//
// Returns the signed low part r0 computed by |decompose| for |x|.
static int32_t low_bits(uint32_t x) {
  uint32_t ignored_high;
  int32_t low;
  decompose(&ignored_high, &low, x);
  return low;
}
386 
387 // FIPS 204, Algorithm 33 (`MakeHint`).
make_hint(uint32_t ct0,uint32_t cs2,uint32_t w)388 static int32_t make_hint(uint32_t ct0, uint32_t cs2, uint32_t w) {
389   uint32_t r_plus_z = reduce_once(kPrime + w - cs2);
390   uint32_t r = reduce_once(r_plus_z + ct0);
391   return high_bits(r) != high_bits(r_plus_z);
392 }
393 
// FIPS 204, Algorithm 34 (`UseHint`).
//
// Returns the high bits of |r|. If the hint |h| is set, the result moves to
// the neighbouring bucket in the direction of the low part, modulo the 16
// buckets. Branches on |h| and the low part, so it is not constant time.
static uint32_t use_hint_vartime(uint32_t h, uint32_t r) {
  uint32_t r1;
  int32_t r0;
  decompose(&r1, &r0, r);

  if (!h) {
    return r1;
  }
  const uint32_t shifted = (r0 > 0) ? r1 + 1 : r1 - 1;
  return shifted & 15;
}
410 
// Applies |power2_round| to every coefficient: s = s1 * 2^kDroppedBits + s0.
static void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {
  for (size_t n = 0; n < DEGREE; n++) {
    power2_round(s1->c + n, s0->c + n, s->c[n]);
  }
}
416 
// Applies |scale_power2_round| to every coefficient of |in|.
static void scalar_scale_power2_round(scalar *out, const scalar *in) {
  for (size_t n = 0; n < DEGREE; n++) {
    scale_power2_round(out->c + n, in->c[n]);
  }
}
422 
// Replaces each coefficient by its |high_bits|.
static void scalar_high_bits(scalar *out, const scalar *in) {
  for (size_t n = 0; n < DEGREE; n++) {
    const uint32_t coeff = in->c[n];
    out->c[n] = high_bits(coeff);
  }
}
428 
// Replaces each coefficient by its signed |low_bits|, stored in the uint32_t
// slot as a two's-complement value.
static void scalar_low_bits(scalar *out, const scalar *in) {
  for (size_t n = 0; n < DEGREE; n++) {
    const uint32_t coeff = in->c[n];
    out->c[n] = low_bits(coeff);
  }
}
434 
scalar_max(uint32_t * max,const scalar * s)435 static void scalar_max(uint32_t *max, const scalar *s) {
436   for (int i = 0; i < DEGREE; i++) {
437     uint32_t abs = abs_mod_prime(s->c[i]);
438     *max = maximum(*max, abs);
439   }
440 }
441 
scalar_max_signed(uint32_t * max,const scalar * s)442 static void scalar_max_signed(uint32_t *max, const scalar *s) {
443   for (int i = 0; i < DEGREE; i++) {
444     uint32_t abs = abs_signed(s->c[i]);
445     *max = maximum(*max, abs);
446   }
447 }
448 
// Computes the hint bit |make_hint|(ct0_i, cs2_i, w_i) for every coefficient.
static void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
                             const scalar *w) {
  for (size_t n = 0; n < DEGREE; n++) {
    out->c[n] = make_hint(ct0->c[n], cs2->c[n], w->c[n]);
  }
}
455 
// Applies |use_hint_vartime|(h_i, r_i) to every coefficient (not constant
// time).
static void scalar_use_hint_vartime(scalar *out, const scalar *h,
                                    const scalar *r) {
  for (size_t n = 0; n < DEGREE; n++) {
    out->c[n] = use_hint_vartime(h->c[n], r->c[n]);
  }
}
462 
// Applies |scalar_power2_round| to each entry: t = t1 * 2^kDroppedBits + t0.
static void vectork_power2_round(vectork *t1, vectork *t0, const vectork *t) {
  for (int row = 0; row < K; ++row) {
    scalar_power2_round(t1->v + row, t0->v + row, t->v + row);
  }
}
468 
// Applies |scalar_scale_power2_round| to each entry of |in|.
static void vectork_scale_power2_round(vectork *out, const vectork *in) {
  for (int row = 0; row < K; ++row) {
    scalar_scale_power2_round(out->v + row, in->v + row);
  }
}
474 
// Applies |scalar_high_bits| to each entry of |in|.
static void vectork_high_bits(vectork *out, const vectork *in) {
  for (int row = 0; row < K; ++row) {
    scalar_high_bits(out->v + row, in->v + row);
  }
}
480 
// Applies |scalar_low_bits| to each entry of |in|.
static void vectork_low_bits(vectork *out, const vectork *in) {
  for (int row = 0; row < K; ++row) {
    scalar_low_bits(out->v + row, in->v + row);
  }
}
486 
vectork_max(const vectork * a)487 static uint32_t vectork_max(const vectork *a) {
488   uint32_t max = 0;
489   for (int i = 0; i < K; i++) {
490     scalar_max(&max, &a->v[i]);
491   }
492   return max;
493 }
494 
vectork_max_signed(const vectork * a)495 static uint32_t vectork_max_signed(const vectork *a) {
496   uint32_t max = 0;
497   for (int i = 0; i < K; i++) {
498     scalar_max_signed(&max, &a->v[i]);
499   }
500   return max;
501 }
502 
503 // The input vector contains only zeroes and ones.
vectork_count_ones(const vectork * a)504 static size_t vectork_count_ones(const vectork *a) {
505   size_t count = 0;
506   for (int i = 0; i < K; i++) {
507     for (int j = 0; j < DEGREE; j++) {
508       count += a->v[i].c[j];
509     }
510   }
511   return count;
512 }
513 
// Applies |scalar_make_hint| to each entry.
static void vectork_make_hint(vectork *out, const vectork *ct0,
                              const vectork *cs2, const vectork *w) {
  for (int row = 0; row < K; ++row) {
    scalar_make_hint(out->v + row, ct0->v + row, cs2->v + row, w->v + row);
  }
}
520 
// Applies |scalar_use_hint_vartime| to each entry (not constant time).
static void vectork_use_hint_vartime(vectork *out, const vectork *h,
                                     const vectork *r) {
  for (int row = 0; row < K; ++row) {
    scalar_use_hint_vartime(out->v + row, h->v + row, r->v + row);
  }
}
527 
vectorl_max(const vectorl * a)528 static uint32_t vectorl_max(const vectorl *a) {
529   uint32_t max = 0;
530   for (int i = 0; i < L; i++) {
531     scalar_max(&max, &a->v[i]);
532   }
533   return max;
534 }
535 
536 /* Bit packing */
537 
// kMasks[n - 1] keeps the low |n| bits of a byte, for n in [1, 8].
static const uint8_t kMasks[8] = {1, 3, 7, 15, 31, 63, 127, 255};
540 
541 // FIPS 204, Algorithm 10 (`SimpleBitPack`).
scalar_encode(uint8_t * out,const scalar * s,int bits)542 static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
543   assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
544 
545   uint8_t out_byte = 0;
546   int out_byte_bits = 0;
547 
548   for (int i = 0; i < DEGREE; i++) {
549     uint32_t element = s->c[i];
550     int element_bits_done = 0;
551 
552     while (element_bits_done < bits) {
553       int chunk_bits = bits - element_bits_done;
554       int out_bits_remaining = 8 - out_byte_bits;
555       if (chunk_bits >= out_bits_remaining) {
556         chunk_bits = out_bits_remaining;
557         out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
558         *out = out_byte;
559         out++;
560         out_byte_bits = 0;
561         out_byte = 0;
562       } else {
563         out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
564         out_byte_bits += chunk_bits;
565       }
566 
567       element_bits_done += chunk_bits;
568       element >>= chunk_bits;
569     }
570   }
571 
572   if (out_byte_bits > 0) {
573     *out = out_byte;
574   }
575 }
576 
577 // FIPS 204, Algorithm 11 (`BitPack`).
scalar_encode_signed(uint8_t * out,const scalar * s,int bits,uint32_t max)578 static void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
579                                  uint32_t max) {
580   assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
581 
582   uint8_t out_byte = 0;
583   int out_byte_bits = 0;
584 
585   for (int i = 0; i < DEGREE; i++) {
586     uint32_t element = reduce_once(kPrime + max - s->c[i]);
587     declassify_assert(element <= 2 * max);
588     int element_bits_done = 0;
589 
590     while (element_bits_done < bits) {
591       int chunk_bits = bits - element_bits_done;
592       int out_bits_remaining = 8 - out_byte_bits;
593       if (chunk_bits >= out_bits_remaining) {
594         chunk_bits = out_bits_remaining;
595         out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
596         *out = out_byte;
597         out++;
598         out_byte_bits = 0;
599         out_byte = 0;
600       } else {
601         out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
602         out_byte_bits += chunk_bits;
603       }
604 
605       element_bits_done += chunk_bits;
606       element >>= chunk_bits;
607     }
608   }
609 
610   if (out_byte_bits > 0) {
611     *out = out_byte;
612   }
613 }
614 
615 // FIPS 204, Algorithm 12 (`SimpleBitUnpack`).
scalar_decode(scalar * out,const uint8_t * in,int bits)616 static void scalar_decode(scalar *out, const uint8_t *in, int bits) {
617   assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
618 
619   uint8_t in_byte = 0;
620   int in_byte_bits_left = 0;
621 
622   for (int i = 0; i < DEGREE; i++) {
623     uint32_t element = 0;
624     int element_bits_done = 0;
625 
626     while (element_bits_done < bits) {
627       if (in_byte_bits_left == 0) {
628         in_byte = *in;
629         in++;
630         in_byte_bits_left = 8;
631       }
632 
633       int chunk_bits = bits - element_bits_done;
634       if (chunk_bits > in_byte_bits_left) {
635         chunk_bits = in_byte_bits_left;
636       }
637 
638       element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
639       in_byte_bits_left -= chunk_bits;
640       in_byte >>= chunk_bits;
641 
642       element_bits_done += chunk_bits;
643     }
644 
645     out->c[i] = element;
646   }
647 }
648 
649 // FIPS 204, Algorithm 13 (`BitUnpack`).
scalar_decode_signed(scalar * out,const uint8_t * in,int bits,uint32_t max)650 static int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
651                                 uint32_t max) {
652   assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
653 
654   uint8_t in_byte = 0;
655   int in_byte_bits_left = 0;
656 
657   for (int i = 0; i < DEGREE; i++) {
658     uint32_t element = 0;
659     int element_bits_done = 0;
660 
661     while (element_bits_done < bits) {
662       if (in_byte_bits_left == 0) {
663         in_byte = *in;
664         in++;
665         in_byte_bits_left = 8;
666       }
667 
668       int chunk_bits = bits - element_bits_done;
669       if (chunk_bits > in_byte_bits_left) {
670         chunk_bits = in_byte_bits_left;
671       }
672 
673       element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
674       in_byte_bits_left -= chunk_bits;
675       in_byte >>= chunk_bits;
676 
677       element_bits_done += chunk_bits;
678     }
679 
680     // This may be only out of range in cases of invalid input, in which case it
681     // is okay to leak the value. This function is also called with secret
682     // input during signing, in |scalar_sample_mask|. However, in that case
683     // (and in any case when |max| is a power of two), this case is impossible.
684     if (constant_time_declassify_int(element > 2 * max)) {
685       return 0;
686     }
687     out->c[i] = reduce_once(kPrime + max - element);
688   }
689 
690   return 1;
691 }
692 
693 /* Expansion functions */
694 
695 // FIPS 204, Algorithm 24 (`RejNTTPoly`).
696 //
697 // Rejection samples a Keccak stream to get uniformly distributed elements. This
698 // is used for matrix expansion and only operates on public inputs.
scalar_from_keccak_vartime(scalar * out,const uint8_t derived_seed[RHO_BYTES+2])699 static void scalar_from_keccak_vartime(
700     scalar *out, const uint8_t derived_seed[RHO_BYTES + 2]) {
701   struct BORINGSSL_keccak_st keccak_ctx;
702   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
703   BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, RHO_BYTES + 2);
704   assert(keccak_ctx.squeeze_offset == 0);
705   assert(keccak_ctx.rate_bytes == 168);
706   static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
707 
708   int done = 0;
709   while (done < DEGREE) {
710     uint8_t block[168];
711     BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
712     for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
713       // FIPS 204, Algorithm 8 (`CoeffFromThreeBytes`).
714       uint32_t value = (uint32_t)block[i] | ((uint32_t)block[i + 1] << 8) |
715                        (((uint32_t)block[i + 2] & 0x7f) << 16);
716       if (value < kPrime) {
717         out->c[done++] = value;
718       }
719     }
720   }
721 }
722 
723 // FIPS 204, Algorithm 25 (`RejBoundedPoly`).
scalar_uniform_eta_4(scalar * out,const uint8_t derived_seed[SIGMA_BYTES+2])724 static void scalar_uniform_eta_4(scalar *out,
725                                  const uint8_t derived_seed[SIGMA_BYTES + 2]) {
726   static_assert(ETA == 4, "This implementation is specialized for ETA == 4");
727 
728   struct BORINGSSL_keccak_st keccak_ctx;
729   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
730   BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, SIGMA_BYTES + 2);
731   assert(keccak_ctx.squeeze_offset == 0);
732   assert(keccak_ctx.rate_bytes == 136);
733 
734   int done = 0;
735   while (done < DEGREE) {
736     uint8_t block[136];
737     BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
738     for (size_t i = 0; i < sizeof(block) && done < DEGREE; ++i) {
739       uint32_t t0 = block[i] & 0x0F;
740       uint32_t t1 = block[i] >> 4;
741       // FIPS 204, Algorithm 9 (`CoefFromHalfByte`). Although both the input and
742       // output here are secret, it is OK to leak when we rejected a byte.
743       // Individual bytes of the SHAKE-256 stream are (indistiguishable from)
744       // independent of each other and the original seed, so leaking information
745       // about the rejected bytes does not reveal the input or output.
746       if (constant_time_declassify_int(t0 < 9)) {
747         out->c[done++] = reduce_once(kPrime + ETA - t0);
748       }
749       if (done < DEGREE && constant_time_declassify_int(t1 < 9)) {
750         out->c[done++] = reduce_once(kPrime + ETA - t1);
751       }
752     }
753   }
754 }
755 
756 // FIPS 204, Algorithm 28 (`ExpandMask`).
scalar_sample_mask(scalar * out,const uint8_t derived_seed[RHO_PRIME_BYTES+2])757 static void scalar_sample_mask(
758     scalar *out, const uint8_t derived_seed[RHO_PRIME_BYTES + 2]) {
759   uint8_t buf[640];
760   BORINGSSL_keccak(buf, sizeof(buf), derived_seed, RHO_PRIME_BYTES + 2,
761                    boringssl_shake256);
762 
763   // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
764   scalar_decode_signed(out, buf, 20, 1 << 19);
765 }
766 
767 // FIPS 204, Algorithm 23 (`SampleInBall`).
scalar_sample_in_ball_vartime(scalar * out,const uint8_t * seed,int len)768 static void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed,
769                                           int len) {
770   assert(len == 32);
771 
772   struct BORINGSSL_keccak_st keccak_ctx;
773   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
774   BORINGSSL_keccak_absorb(&keccak_ctx, seed, len);
775   assert(keccak_ctx.squeeze_offset == 0);
776   assert(keccak_ctx.rate_bytes == 136);
777 
778   uint8_t block[136];
779   BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
780 
781   uint64_t signs = CRYPTO_load_u64_le(block);
782   int offset = 8;
783   // SampleInBall implements a Fisher–Yates shuffle, which unavoidably leaks
784   // where the zeros are by memory access pattern. Although this leak happens
785   // before bad signatures are rejected, this is safe. See
786   // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/8d8f01ac_70af3f21/
787   CONSTTIME_DECLASSIFY(block + offset, sizeof(block) - offset);
788 
789   OPENSSL_memset(out, 0, sizeof(*out));
790   for (size_t i = DEGREE - TAU; i < DEGREE; i++) {
791     size_t byte;
792     for (;;) {
793       if (offset == 136) {
794         BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
795         // See above.
796         CONSTTIME_DECLASSIFY(block, sizeof(block));
797         offset = 0;
798       }
799 
800       byte = block[offset++];
801       if (byte <= i) {
802         break;
803       }
804     }
805 
806     out->c[i] = out->c[byte];
807     out->c[byte] = reduce_once(kPrime + 1 - 2 * (signs & 1));
808     signs >>= 1;
809   }
810 }
811 
812 // FIPS 204, Algorithm 26 (`ExpandA`).
matrix_expand(matrix * out,const uint8_t rho[RHO_BYTES])813 static void matrix_expand(matrix *out, const uint8_t rho[RHO_BYTES]) {
814   static_assert(K <= 0x100, "K must fit in 8 bits");
815   static_assert(L <= 0x100, "L must fit in 8 bits");
816 
817   uint8_t derived_seed[RHO_BYTES + 2];
818   OPENSSL_memcpy(derived_seed, rho, RHO_BYTES);
819   for (int i = 0; i < K; i++) {
820     for (int j = 0; j < L; j++) {
821       derived_seed[RHO_BYTES + 1] = i;
822       derived_seed[RHO_BYTES] = j;
823       scalar_from_keccak_vartime(&out->v[i][j], derived_seed);
824     }
825   }
826 }
827 
828 // FIPS 204, Algorithm 27 (`ExpandS`).
vector_expand_short(vectorl * s1,vectork * s2,const uint8_t sigma[SIGMA_BYTES])829 static void vector_expand_short(vectorl *s1, vectork *s2,
830                                 const uint8_t sigma[SIGMA_BYTES]) {
831   static_assert(K <= 0x100, "K must fit in 8 bits");
832   static_assert(L <= 0x100, "L must fit in 8 bits");
833   static_assert(K + L <= 0x100, "K+L must fit in 8 bits");
834 
835   uint8_t derived_seed[SIGMA_BYTES + 2];
836   OPENSSL_memcpy(derived_seed, sigma, SIGMA_BYTES);
837   derived_seed[SIGMA_BYTES] = 0;
838   derived_seed[SIGMA_BYTES + 1] = 0;
839   for (int i = 0; i < L; i++) {
840     scalar_uniform_eta_4(&s1->v[i], derived_seed);
841     ++derived_seed[SIGMA_BYTES];
842   }
843   for (int i = 0; i < K; i++) {
844     scalar_uniform_eta_4(&s2->v[i], derived_seed);
845     ++derived_seed[SIGMA_BYTES];
846   }
847 }
848 
849 // FIPS 204, Algorithm 28 (`ExpandMask`).
vectorl_expand_mask(vectorl * out,const uint8_t seed[RHO_PRIME_BYTES],size_t kappa)850 static void vectorl_expand_mask(vectorl *out,
851                                 const uint8_t seed[RHO_PRIME_BYTES],
852                                 size_t kappa) {
853   assert(kappa + L <= 0x10000);
854 
855   uint8_t derived_seed[RHO_PRIME_BYTES + 2];
856   OPENSSL_memcpy(derived_seed, seed, RHO_PRIME_BYTES);
857   for (int i = 0; i < L; i++) {
858     size_t index = kappa + i;
859     derived_seed[RHO_PRIME_BYTES] = index & 0xFF;
860     derived_seed[RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
861     scalar_sample_mask(&out->v[i], derived_seed);
862   }
863 }
864 
865 /* Encoding */
866 
867 // FIPS 204, Algorithm 10 (`SimpleBitPack`).
868 //
869 // Encodes an entire vector into 32*K*|bits| bytes. Note that since 256 (DEGREE)
870 // is divisible by 8, the individual vector entries will always fill a whole
871 // number of bytes, so we do not need to worry about bit packing here.
vectork_encode(uint8_t * out,const vectork * a,int bits)872 static void vectork_encode(uint8_t *out, const vectork *a, int bits) {
873   for (int i = 0; i < K; i++) {
874     scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
875   }
876 }
877 
878 // FIPS 204, Algorithm 12 (`SimpleBitUnpack`).
vectork_decode(vectork * out,const uint8_t * in,int bits)879 static void vectork_decode(vectork *out, const uint8_t *in, int bits) {
880   for (int i = 0; i < K; i++) {
881     scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits);
882   }
883 }
884 
vectork_encode_signed(uint8_t * out,const vectork * a,int bits,uint32_t max)885 static void vectork_encode_signed(uint8_t *out, const vectork *a, int bits,
886                                   uint32_t max) {
887   for (int i = 0; i < K; i++) {
888     scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
889   }
890 }
891 
vectork_decode_signed(vectork * out,const uint8_t * in,int bits,uint32_t max)892 static int vectork_decode_signed(vectork *out, const uint8_t *in, int bits,
893                                  uint32_t max) {
894   for (int i = 0; i < K; i++) {
895     if (!scalar_decode_signed(&out->v[i], in + i * bits * DEGREE / 8, bits,
896                               max)) {
897       return 0;
898     }
899   }
900   return 1;
901 }
902 
903 // FIPS 204, Algorithm 11 (`BitPack`).
904 //
905 // Encodes an entire vector into 32*L*|bits| bytes. Note that since 256 (DEGREE)
906 // is divisible by 8, the individual vector entries will always fill a whole
907 // number of bytes, so we do not need to worry about bit packing here.
vectorl_encode_signed(uint8_t * out,const vectorl * a,int bits,uint32_t max)908 static void vectorl_encode_signed(uint8_t *out, const vectorl *a, int bits,
909                                   uint32_t max) {
910   for (int i = 0; i < L; i++) {
911     scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
912   }
913 }
914 
vectorl_decode_signed(vectorl * out,const uint8_t * in,int bits,uint32_t max)915 static int vectorl_decode_signed(vectorl *out, const uint8_t *in, int bits,
916                                  uint32_t max) {
917   for (int i = 0; i < L; i++) {
918     if (!scalar_decode_signed(&out->v[i], in + i * bits * DEGREE / 8, bits,
919                               max)) {
920       return 0;
921     }
922   }
923   return 1;
924 }
925 
926 // FIPS 204, Algorithm 22 (`w1Encode`).
927 //
928 // The output must point to an array of 128*K bytes.
w1_encode(uint8_t * out,const vectork * w1)929 static void w1_encode(uint8_t *out, const vectork *w1) {
930   vectork_encode(out, w1, 4);
931 }
932 
933 // FIPS 204, Algorithm 14 (`HintBitPack`).
hint_bit_pack(uint8_t * out,const vectork * h)934 static void hint_bit_pack(uint8_t *out, const vectork *h) {
935   OPENSSL_memset(out, 0, OMEGA + K);
936   int index = 0;
937   for (int i = 0; i < K; i++) {
938     for (int j = 0; j < DEGREE; j++) {
939       if (h->v[i].c[j]) {
940         out[index++] = j;
941       }
942     }
943     out[OMEGA + i] = index;
944   }
945 }
946 
947 // FIPS 204, Algorithm 15 (`HintBitUnpack`).
hint_bit_unpack(vectork * h,const uint8_t * in)948 static int hint_bit_unpack(vectork *h, const uint8_t *in) {
949   vectork_zero(h);
950   int index = 0;
951   for (int i = 0; i < K; i++) {
952     int limit = in[OMEGA + i];
953     if (limit < index || limit > OMEGA) {
954       return 0;
955     }
956 
957     int last = -1;
958     while (index < limit) {
959       int byte = in[index++];
960       if (last >= 0 && byte <= last) {
961         return 0;
962       }
963       last = byte;
964       h->v[i].c[byte] = 1;
965     }
966   }
967   for (; index < OMEGA; index++) {
968     if (in[index] != 0) {
969       return 0;
970     }
971   }
972   return 1;
973 }
974 
975 struct public_key {
976   uint8_t rho[RHO_BYTES];
977   vectork t1;
978   // Pre-cached value(s).
979   uint8_t public_key_hash[TR_BYTES];
980 };
981 
982 struct private_key {
983   uint8_t rho[RHO_BYTES];
984   uint8_t k[K_BYTES];
985   uint8_t public_key_hash[TR_BYTES];
986   vectorl s1;
987   vectork s2;
988   vectork t0;
989 };
990 
991 struct signature {
992   uint8_t c_tilde[2 * LAMBDA_BYTES];
993   vectorl z;
994   vectork h;
995 };
996 
997 // FIPS 204, Algorithm 16 (`pkEncode`).
dilithium_marshal_public_key(CBB * out,const struct public_key * pub)998 static int dilithium_marshal_public_key(CBB *out,
999                                         const struct public_key *pub) {
1000   if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
1001     return 0;
1002   }
1003 
1004   uint8_t *vectork_output;
1005   if (!CBB_add_space(out, &vectork_output, 320 * K)) {
1006     return 0;
1007   }
1008   vectork_encode(vectork_output, &pub->t1, 10);
1009 
1010   return 1;
1011 }
1012 
1013 // FIPS 204, Algorithm 17 (`pkDecode`).
dilithium_parse_public_key(struct public_key * pub,CBS * in)1014 static int dilithium_parse_public_key(struct public_key *pub, CBS *in) {
1015   if (!CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
1016     return 0;
1017   }
1018 
1019   CBS t1_bytes;
1020   if (!CBS_get_bytes(in, &t1_bytes, 320 * K)) {
1021     return 0;
1022   }
1023   vectork_decode(&pub->t1, CBS_data(&t1_bytes), 10);
1024 
1025   return 1;
1026 }
1027 
1028 // FIPS 204, Algorithm 18 (`skEncode`).
dilithium_marshal_private_key(CBB * out,const struct private_key * priv)1029 static int dilithium_marshal_private_key(CBB *out,
1030                                          const struct private_key *priv) {
1031   if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
1032       !CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
1033       !CBB_add_bytes(out, priv->public_key_hash,
1034                      sizeof(priv->public_key_hash))) {
1035     return 0;
1036   }
1037 
1038   uint8_t *vectorl_output;
1039   if (!CBB_add_space(out, &vectorl_output, 128 * L)) {
1040     return 0;
1041   }
1042   vectorl_encode_signed(vectorl_output, &priv->s1, 4, ETA);
1043 
1044   uint8_t *vectork_output;
1045   if (!CBB_add_space(out, &vectork_output, 128 * K)) {
1046     return 0;
1047   }
1048   vectork_encode_signed(vectork_output, &priv->s2, 4, ETA);
1049 
1050   if (!CBB_add_space(out, &vectork_output, 416 * K)) {
1051     return 0;
1052   }
1053   vectork_encode_signed(vectork_output, &priv->t0, 13, 1 << 12);
1054 
1055   return 1;
1056 }
1057 
1058 // FIPS 204, Algorithm 19 (`skDecode`).
dilithium_parse_private_key(struct private_key * priv,CBS * in)1059 static int dilithium_parse_private_key(struct private_key *priv, CBS *in) {
1060   CBS s1_bytes;
1061   CBS s2_bytes;
1062   CBS t0_bytes;
1063   if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
1064       !CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
1065       !CBS_copy_bytes(in, priv->public_key_hash,
1066                       sizeof(priv->public_key_hash)) ||
1067       !CBS_get_bytes(in, &s1_bytes, 128 * L) ||
1068       !vectorl_decode_signed(&priv->s1, CBS_data(&s1_bytes), 4, ETA) ||
1069       !CBS_get_bytes(in, &s2_bytes, 128 * K) ||
1070       !vectork_decode_signed(&priv->s2, CBS_data(&s2_bytes), 4, ETA) ||
1071       !CBS_get_bytes(in, &t0_bytes, 416 * K) ||
1072       // Note: Decoding 13 bits into (-2^12, 2^12] cannot fail.
1073       !vectork_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
1074     return 0;
1075   }
1076 
1077   return 1;
1078 }
1079 
1080 // FIPS 204, Algorithm 20 (`sigEncode`).
dilithium_marshal_signature(CBB * out,const struct signature * sign)1081 static int dilithium_marshal_signature(CBB *out, const struct signature *sign) {
1082   if (!CBB_add_bytes(out, sign->c_tilde, sizeof(sign->c_tilde))) {
1083     return 0;
1084   }
1085 
1086   uint8_t *vectorl_output;
1087   if (!CBB_add_space(out, &vectorl_output, 640 * L)) {
1088     return 0;
1089   }
1090   vectorl_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
1091 
1092   uint8_t *hint_output;
1093   if (!CBB_add_space(out, &hint_output, OMEGA + K)) {
1094     return 0;
1095   }
1096   hint_bit_pack(hint_output, &sign->h);
1097 
1098   return 1;
1099 }
1100 
1101 // FIPS 204, Algorithm 21 (`sigDecode`).
dilithium_parse_signature(struct signature * sign,CBS * in)1102 static int dilithium_parse_signature(struct signature *sign, CBS *in) {
1103   CBS z_bytes;
1104   CBS hint_bytes;
1105   if (!CBS_copy_bytes(in, sign->c_tilde, sizeof(sign->c_tilde)) ||
1106       !CBS_get_bytes(in, &z_bytes, 640 * L) ||
1107       // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
1108       !vectorl_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
1109       !CBS_get_bytes(in, &hint_bytes, OMEGA + K) ||
1110       !hint_bit_unpack(&sign->h, CBS_data(&hint_bytes))) {
1111     return 0;
1112   };
1113 
1114   return 1;
1115 }
1116 
private_key_from_external(const struct DILITHIUM_private_key * external)1117 static struct private_key *private_key_from_external(
1118     const struct DILITHIUM_private_key *external) {
1119   static_assert(
1120       sizeof(struct DILITHIUM_private_key) == sizeof(struct private_key),
1121       "Kyber private key size incorrect");
1122   static_assert(
1123       alignof(struct DILITHIUM_private_key) == alignof(struct private_key),
1124       "Kyber private key align incorrect");
1125   return (struct private_key *)external;
1126 }
1127 
public_key_from_external(const struct DILITHIUM_public_key * external)1128 static struct public_key *public_key_from_external(
1129     const struct DILITHIUM_public_key *external) {
1130   static_assert(
1131       sizeof(struct DILITHIUM_public_key) == sizeof(struct public_key),
1132       "Dilithium public key size incorrect");
1133   static_assert(
1134       alignof(struct DILITHIUM_public_key) == alignof(struct public_key),
1135       "Dilithium public key align incorrect");
1136   return (struct public_key *)external;
1137 }
1138 
1139 /* API */
1140 
1141 // Calls |DILITHIUM_generate_key_external_entropy| with random bytes from
1142 // |RAND_bytes|. Returns 1 on success and 0 on failure.
DILITHIUM_generate_key(uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],struct DILITHIUM_private_key * out_private_key)1143 int DILITHIUM_generate_key(
1144     uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],
1145     struct DILITHIUM_private_key *out_private_key) {
1146   uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY];
1147   RAND_bytes(entropy, sizeof(entropy));
1148   return DILITHIUM_generate_key_external_entropy(out_encoded_public_key,
1149                                                  out_private_key, entropy);
1150 }
1151 
1152 // FIPS 204, Algorithm 1 (`ML-DSA.KeyGen`). Returns 1 on success and 0 on
1153 // failure.
DILITHIUM_generate_key_external_entropy(uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],struct DILITHIUM_private_key * out_private_key,const uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY])1154 int DILITHIUM_generate_key_external_entropy(
1155     uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],
1156     struct DILITHIUM_private_key *out_private_key,
1157     const uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY]) {
1158   int ret = 0;
1159 
1160   // Intermediate values, allocated on the heap to allow use when there is a
1161   // limited amount of stack.
1162   struct values_st {
1163     struct public_key pub;
1164     matrix a_ntt;
1165     vectorl s1_ntt;
1166     vectork t;
1167   };
1168   struct values_st *values = OPENSSL_malloc(sizeof(*values));
1169   if (values == NULL) {
1170     goto err;
1171   }
1172 
1173   struct private_key *priv = private_key_from_external(out_private_key);
1174 
1175   uint8_t expanded_seed[RHO_BYTES + SIGMA_BYTES + K_BYTES];
1176   BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), entropy,
1177                    DILITHIUM_GENERATE_KEY_ENTROPY, boringssl_shake256);
1178   const uint8_t *const rho = expanded_seed;
1179   const uint8_t *const sigma = expanded_seed + RHO_BYTES;
1180   const uint8_t *const k = expanded_seed + RHO_BYTES + SIGMA_BYTES;
1181   // rho is public.
1182   CONSTTIME_DECLASSIFY(rho, RHO_BYTES);
1183   OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
1184   OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
1185   OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
1186 
1187   matrix_expand(&values->a_ntt, rho);
1188   vector_expand_short(&priv->s1, &priv->s2, sigma);
1189 
1190   OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1191   vectorl_ntt(&values->s1_ntt);
1192 
1193   matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
1194   vectork_inverse_ntt(&values->t);
1195   vectork_add(&values->t, &values->t, &priv->s2);
1196 
1197   vectork_power2_round(&values->pub.t1, &priv->t0, &values->t);
1198   // t1 is public.
1199   CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));
1200 
1201   CBB cbb;
1202   CBB_init_fixed(&cbb, out_encoded_public_key, DILITHIUM_PUBLIC_KEY_BYTES);
1203   if (!dilithium_marshal_public_key(&cbb, &values->pub)) {
1204     goto err;
1205   }
1206 
1207   BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
1208                    out_encoded_public_key, DILITHIUM_PUBLIC_KEY_BYTES,
1209                    boringssl_shake256);
1210 
1211   ret = 1;
1212 err:
1213   OPENSSL_free(values);
1214   return ret;
1215 }
1216 
DILITHIUM_public_from_private(struct DILITHIUM_public_key * out_public_key,const struct DILITHIUM_private_key * private_key)1217 int DILITHIUM_public_from_private(
1218     struct DILITHIUM_public_key *out_public_key,
1219     const struct DILITHIUM_private_key *private_key) {
1220   int ret = 0;
1221 
1222   // Intermediate values, allocated on the heap to allow use when there is a
1223   // limited amount of stack.
1224   struct values_st {
1225     matrix a_ntt;
1226     vectorl s1_ntt;
1227     vectork t;
1228     vectork t0;
1229   };
1230   struct values_st *values = OPENSSL_malloc(sizeof(*values));
1231   if (values == NULL) {
1232     goto err;
1233   }
1234 
1235   const struct private_key *priv = private_key_from_external(private_key);
1236   struct public_key *pub = public_key_from_external(out_public_key);
1237 
1238   OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
1239   OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
1240                  sizeof(pub->public_key_hash));
1241 
1242   matrix_expand(&values->a_ntt, priv->rho);
1243 
1244   OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1245   vectorl_ntt(&values->s1_ntt);
1246 
1247   matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
1248   vectork_inverse_ntt(&values->t);
1249   vectork_add(&values->t, &values->t, &priv->s2);
1250 
1251   vectork_power2_round(&pub->t1, &values->t0, &values->t);
1252 
1253   ret = 1;
1254 err:
1255   OPENSSL_free(values);
1256   return ret;
1257 }
1258 
1259 // FIPS 204, Algorithm 2 (`ML-DSA.Sign`). Returns 1 on success and 0 on failure.
dilithium_sign_with_randomizer(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len,const uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES])1260 static int dilithium_sign_with_randomizer(
1261     uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1262     const struct DILITHIUM_private_key *private_key, const uint8_t *msg,
1263     size_t msg_len,
1264     const uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES]) {
1265   int ret = 0;
1266 
1267   const struct private_key *priv = private_key_from_external(private_key);
1268 
1269   uint8_t mu[MU_BYTES];
1270   struct BORINGSSL_keccak_st keccak_ctx;
1271   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1272   BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
1273                           sizeof(priv->public_key_hash));
1274   BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
1275   BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
1276 
1277   uint8_t rho_prime[RHO_PRIME_BYTES];
1278   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1279   BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
1280   BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
1281                           DILITHIUM_SIGNATURE_RANDOMIZER_BYTES);
1282   BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
1283   BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, RHO_PRIME_BYTES);
1284 
1285   // Intermediate values, allocated on the heap to allow use when there is a
1286   // limited amount of stack.
1287   struct values_st {
1288     struct signature sign;
1289     vectorl s1_ntt;
1290     vectork s2_ntt;
1291     vectork t0_ntt;
1292     matrix a_ntt;
1293     vectorl y;
1294     vectorl y_ntt;
1295     vectork w;
1296     vectork w1;
1297     vectorl cs1;
1298     vectork cs2;
1299     vectork r0;
1300     vectork ct0;
1301   };
1302   struct values_st *values = OPENSSL_malloc(sizeof(*values));
1303   if (values == NULL) {
1304     goto err;
1305   }
1306   OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1307   vectorl_ntt(&values->s1_ntt);
1308 
1309   OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
1310   vectork_ntt(&values->s2_ntt);
1311 
1312   OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
1313   vectork_ntt(&values->t0_ntt);
1314 
1315   matrix_expand(&values->a_ntt, priv->rho);
1316 
1317   for (size_t kappa = 0;; kappa += L) {
1318     // TODO(bbe): y only lives long enough to compute y_ntt.
1319     // consider using another vectorl to save memory.
1320     vectorl_expand_mask(&values->y, rho_prime, kappa);
1321 
1322     OPENSSL_memcpy(&values->y_ntt, &values->y, sizeof(values->y_ntt));
1323     vectorl_ntt(&values->y_ntt);
1324 
1325     // TODO(bbe): w only lives long enough to compute y_ntt.
1326     // consider using another vectork to save memory.
1327     matrix_mult(&values->w, &values->a_ntt, &values->y_ntt);
1328     vectork_inverse_ntt(&values->w);
1329 
1330     vectork_high_bits(&values->w1, &values->w);
1331     uint8_t w1_encoded[128 * K];
1332     w1_encode(w1_encoded, &values->w1);
1333 
1334     BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1335     BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
1336     BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
1337     BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
1338                              2 * LAMBDA_BYTES);
1339 
1340     scalar c_ntt;
1341     scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde, 32);
1342     scalar_ntt(&c_ntt);
1343 
1344     vectorl_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
1345     vectorl_inverse_ntt(&values->cs1);
1346     vectork_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
1347     vectork_inverse_ntt(&values->cs2);
1348 
1349     vectorl_add(&values->sign.z, &values->y, &values->cs1);
1350 
1351     vectork_sub(&values->r0, &values->w, &values->cs2);
1352     vectork_low_bits(&values->r0, &values->r0);
1353 
1354     // Leaking the fact that a signature was rejected is fine as the next
1355     // attempt at a signature will be (indistinguishable from) independent of
1356     // this one. Note, however, that we additionally leak which of the two
1357     // branches rejected the signature. Section 5.5 of
1358     // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
1359     // describes this leak as OK. Note we leak less than what is described by
1360     // the paper; we do not reveal which coefficient violated the bound, and we
1361     // hide which of the |z_max| or |r0_max| bound failed. See also
1362     // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
1363     uint32_t z_max = vectorl_max(&values->sign.z);
1364     uint32_t r0_max = vectork_max_signed(&values->r0);
1365     if (constant_time_declassify_w(
1366             constant_time_ge_w(z_max, kGamma1 - BETA) |
1367             constant_time_ge_w(r0_max, kGamma2 - BETA))) {
1368       continue;
1369     }
1370 
1371     vectork_mult_scalar(&values->ct0, &values->t0_ntt, &c_ntt);
1372     vectork_inverse_ntt(&values->ct0);
1373     vectork_make_hint(&values->sign.h, &values->ct0, &values->cs2, &values->w);
1374 
1375     // See above.
1376     uint32_t ct0_max = vectork_max(&values->ct0);
1377     size_t h_ones = vectork_count_ones(&values->sign.h);
1378     if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
1379                                    constant_time_lt_w(OMEGA, h_ones))) {
1380       continue;
1381     }
1382 
1383     // Although computed with the private key, the signature is public.
1384     CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
1385     CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
1386     CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
1387 
1388     CBB cbb;
1389     CBB_init_fixed(&cbb, out_encoded_signature, DILITHIUM_SIGNATURE_BYTES);
1390     if (!dilithium_marshal_signature(&cbb, &values->sign)) {
1391       goto err;
1392     }
1393 
1394     BSSL_CHECK(CBB_len(&cbb) == DILITHIUM_SIGNATURE_BYTES);
1395     ret = 1;
1396     break;
1397   }
1398 
1399 err:
1400   OPENSSL_free(values);
1401   return ret;
1402 }
1403 
1404 // Dilithium signature in deterministic mode. Returns 1 on success and 0 on
1405 // failure.
DILITHIUM_sign_deterministic(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len)1406 int DILITHIUM_sign_deterministic(
1407     uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1408     const struct DILITHIUM_private_key *private_key, const uint8_t *msg,
1409     size_t msg_len) {
1410   uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES];
1411   OPENSSL_memset(randomizer, 0, sizeof(randomizer));
1412   return dilithium_sign_with_randomizer(out_encoded_signature, private_key, msg,
1413                                         msg_len, randomizer);
1414 }
1415 
1416 // Dilithium signature in randomized mode, filling the random bytes with
1417 // |RAND_bytes|. Returns 1 on success and 0 on failure.
DILITHIUM_sign(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len)1418 int DILITHIUM_sign(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1419                    const struct DILITHIUM_private_key *private_key,
1420                    const uint8_t *msg, size_t msg_len) {
1421   uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES];
1422   RAND_bytes(randomizer, sizeof(randomizer));
1423   return dilithium_sign_with_randomizer(out_encoded_signature, private_key, msg,
1424                                         msg_len, randomizer);
1425 }
1426 
1427 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
DILITHIUM_verify(const struct DILITHIUM_public_key * public_key,const uint8_t encoded_signature[DILITHIUM_SIGNATURE_BYTES],const uint8_t * msg,size_t msg_len)1428 int DILITHIUM_verify(const struct DILITHIUM_public_key *public_key,
1429                      const uint8_t encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1430                      const uint8_t *msg, size_t msg_len) {
1431   int ret = 0;
1432 
1433   // Intermediate values, allocated on the heap to allow use when there is a
1434   // limited amount of stack.
1435   struct values_st {
1436     struct signature sign;
1437     matrix a_ntt;
1438     vectorl z_ntt;
1439     vectork az_ntt;
1440     vectork t1_ntt;
1441     vectork ct1_ntt;
1442     vectork w_approx;
1443     vectork w1;
1444   };
1445   struct values_st *values = OPENSSL_malloc(sizeof(*values));
1446   if (values == NULL) {
1447     goto err;
1448   }
1449 
1450   const struct public_key *pub = public_key_from_external(public_key);
1451 
1452   CBS cbs;
1453   CBS_init(&cbs, encoded_signature, DILITHIUM_SIGNATURE_BYTES);
1454   if (!dilithium_parse_signature(&values->sign, &cbs)) {
1455     goto err;
1456   }
1457 
1458   matrix_expand(&values->a_ntt, pub->rho);
1459 
1460   uint8_t mu[MU_BYTES];
1461   struct BORINGSSL_keccak_st keccak_ctx;
1462   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1463   BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
1464                           sizeof(pub->public_key_hash));
1465   BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
1466   BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
1467 
1468   scalar c_ntt;
1469   scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde, 32);
1470   scalar_ntt(&c_ntt);
1471 
1472   OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
1473   vectorl_ntt(&values->z_ntt);
1474 
1475   matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);
1476 
1477   vectork_scale_power2_round(&values->t1_ntt, &pub->t1);
1478   vectork_ntt(&values->t1_ntt);
1479 
1480   vectork_mult_scalar(&values->ct1_ntt, &values->t1_ntt, &c_ntt);
1481 
1482   vectork_sub(&values->w_approx, &values->az_ntt, &values->ct1_ntt);
1483   vectork_inverse_ntt(&values->w_approx);
1484 
1485   vectork_use_hint_vartime(&values->w1, &values->sign.h, &values->w_approx);
1486   uint8_t w1_encoded[128 * K];
1487   w1_encode(w1_encoded, &values->w1);
1488 
1489   uint8_t c_tilde[2 * LAMBDA_BYTES];
1490   BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1491   BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
1492   BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
1493   BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * LAMBDA_BYTES);
1494 
1495   uint32_t z_max = vectorl_max(&values->sign.z);
1496   size_t h_ones = vectork_count_ones(&values->sign.h);
1497   if (z_max < kGamma1 - BETA && h_ones <= OMEGA &&
1498       OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * LAMBDA_BYTES) == 0) {
1499     ret = 1;
1500   }
1501 
1502 err:
1503   OPENSSL_free(values);
1504   return ret;
1505 }
1506 
1507 /* Serialization of keys. */
1508 
DILITHIUM_marshal_public_key(CBB * out,const struct DILITHIUM_public_key * public_key)1509 int DILITHIUM_marshal_public_key(
1510     CBB *out, const struct DILITHIUM_public_key *public_key) {
1511   return dilithium_marshal_public_key(out,
1512                                       public_key_from_external(public_key));
1513 }
1514 
DILITHIUM_parse_public_key(struct DILITHIUM_public_key * public_key,CBS * in)1515 int DILITHIUM_parse_public_key(struct DILITHIUM_public_key *public_key,
1516                                CBS *in) {
1517   struct public_key *pub = public_key_from_external(public_key);
1518   CBS orig_in = *in;
1519   if (!dilithium_parse_public_key(pub, in) || CBS_len(in) != 0) {
1520     return 0;
1521   }
1522 
1523   // Compute pre-cached values.
1524   BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
1525                    CBS_data(&orig_in), CBS_len(&orig_in), boringssl_shake256);
1526   return 1;
1527 }
1528 
DILITHIUM_marshal_private_key(CBB * out,const struct DILITHIUM_private_key * private_key)1529 int DILITHIUM_marshal_private_key(
1530     CBB *out, const struct DILITHIUM_private_key *private_key) {
1531   return dilithium_marshal_private_key(out,
1532                                        private_key_from_external(private_key));
1533 }
1534 
DILITHIUM_parse_private_key(struct DILITHIUM_private_key * private_key,CBS * in)1535 int DILITHIUM_parse_private_key(struct DILITHIUM_private_key *private_key,
1536                                 CBS *in) {
1537   struct private_key *priv = private_key_from_external(private_key);
1538   return dilithium_parse_private_key(priv, in) && CBS_len(in) == 0;
1539 }
1540