xref: /aosp_15_r20/external/cronet/third_party/boringssl/src/crypto/spx/wots.c (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 /* Copyright (c) 2023, Google LLC
2  *
3  * Permission to use, copy, modify, and/or distribute this software for any
4  * purpose with or without fee is hereby granted, provided that the above
5  * copyright notice and this permission notice appear in all copies.
6  *
7  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14 
15 #include <openssl/base.h>
16 
17 #include <stdint.h>
18 #include <stdio.h>
19 #include <string.h>
20 
21 #include "./address.h"
22 #include "./params.h"
23 #include "./spx_util.h"
24 #include "./thash.h"
25 #include "./wots.h"
26 
27 // Chaining function used in WOTS+.
chain(uint8_t * output,const uint8_t * input,uint32_t start,uint32_t steps,const uint8_t * pub_seed,uint8_t addr[32])28 static void chain(uint8_t *output, const uint8_t *input, uint32_t start,
29                   uint32_t steps, const uint8_t *pub_seed, uint8_t addr[32]) {
30   memcpy(output, input, SPX_N);
31 
32   for (size_t i = start; i < (start + steps) && i < SPX_WOTS_W; ++i) {
33     spx_set_hash_addr(addr, i);
34     spx_thash_f(output, output, pub_seed, addr);
35   }
36 }
37 
spx_wots_pk_from_sig(uint8_t * pk,const uint8_t * sig,const uint8_t * msg,const uint8_t pub_seed[SPX_N],uint8_t addr[32])38 void spx_wots_pk_from_sig(uint8_t *pk, const uint8_t *sig, const uint8_t *msg,
39                           const uint8_t pub_seed[SPX_N], uint8_t addr[32]) {
40   uint8_t tmp[SPX_WOTS_BYTES];
41   uint8_t wots_pk_addr[32];
42   memcpy(wots_pk_addr, addr, sizeof(wots_pk_addr));
43 
44   // Convert message to base w
45   uint32_t base_w_msg[SPX_WOTS_LEN];
46   spx_base_b(base_w_msg, SPX_WOTS_LEN1, msg, /*log2_b=*/SPX_WOTS_LOG_W);
47 
48   // Compute checksum
49   uint64_t csum = 0;
50   for (size_t i = 0; i < SPX_WOTS_LEN1; ++i) {
51     csum += SPX_WOTS_W - 1 - base_w_msg[i];
52   }
53 
54   // Convert csum to base w as in Algorithm 7, Line 9
55   uint8_t csum_bytes[(SPX_WOTS_LEN2 * SPX_WOTS_LOG_W + 7) / 8];
56   csum = csum << ((8 - ((SPX_WOTS_LEN2 * SPX_WOTS_LOG_W)) % 8) % 8);
57   spx_uint64_to_len_bytes(csum_bytes, sizeof(csum_bytes), csum);
58 
59   // Write the base w representation of csum to the end of the message.
60   spx_base_b(base_w_msg + SPX_WOTS_LEN1, SPX_WOTS_LEN2, csum_bytes,
61              /*log2_b=*/SPX_WOTS_LOG_W);
62 
63   // Compute chains
64   for (size_t i = 0; i < SPX_WOTS_LEN; ++i) {
65     spx_set_chain_addr(addr, i);
66     chain(tmp + i * SPX_N, sig + i * SPX_N, base_w_msg[i],
67           SPX_WOTS_W - 1 - base_w_msg[i], pub_seed, addr);
68   }
69 
70   // Compress pk
71   spx_set_type(wots_pk_addr, SPX_ADDR_TYPE_WOTSPK);
72   spx_copy_keypair_addr(wots_pk_addr, addr);
73   spx_thash_tl(pk, tmp, pub_seed, wots_pk_addr);
74 }
75 
spx_wots_pk_gen(uint8_t * pk,const uint8_t sk_seed[SPX_N],const uint8_t pub_seed[SPX_N],uint8_t addr[32])76 void spx_wots_pk_gen(uint8_t *pk, const uint8_t sk_seed[SPX_N],
77                      const uint8_t pub_seed[SPX_N], uint8_t addr[32]) {
78   uint8_t tmp[SPX_WOTS_BYTES];
79   uint8_t tmp_sk[SPX_N];
80   uint8_t wots_pk_addr[32], sk_addr[32];
81   memcpy(wots_pk_addr, addr, sizeof(wots_pk_addr));
82   memcpy(sk_addr, addr, sizeof(sk_addr));
83 
84   spx_set_type(sk_addr, SPX_ADDR_TYPE_WOTSPRF);
85   spx_copy_keypair_addr(sk_addr, addr);
86 
87   for (size_t i = 0; i < SPX_WOTS_LEN; ++i) {
88     spx_set_chain_addr(sk_addr, i);
89     spx_thash_prf(tmp_sk, pub_seed, sk_seed, sk_addr);
90     spx_set_chain_addr(addr, i);
91     chain(tmp + i * SPX_N, tmp_sk, 0, SPX_WOTS_W - 1, pub_seed, addr);
92   }
93 
94   // Compress pk
95   spx_set_type(wots_pk_addr, SPX_ADDR_TYPE_WOTSPK);
96   spx_copy_keypair_addr(wots_pk_addr, addr);
97   spx_thash_tl(pk, tmp, pub_seed, wots_pk_addr);
98 }
99 
spx_wots_sign(uint8_t * sig,const uint8_t msg[SPX_N],const uint8_t sk_seed[SPX_N],const uint8_t pub_seed[SPX_N],uint8_t addr[32])100 void spx_wots_sign(uint8_t *sig, const uint8_t msg[SPX_N],
101                    const uint8_t sk_seed[SPX_N], const uint8_t pub_seed[SPX_N],
102                    uint8_t addr[32]) {
103   // Convert message to base w
104   uint32_t base_w_msg[SPX_WOTS_LEN];
105   spx_base_b(base_w_msg, SPX_WOTS_LEN1, msg, /*log2_b=*/SPX_WOTS_LOG_W);
106 
107   // Compute checksum
108   uint64_t csum = 0;
109   for (size_t i = 0; i < SPX_WOTS_LEN1; ++i) {
110     csum += SPX_WOTS_W - 1 - base_w_msg[i];
111   }
112 
113   // Convert csum to base w as in Algorithm 6, Line 9
114   uint8_t csum_bytes[(SPX_WOTS_LEN2 * SPX_WOTS_LOG_W + 7) / 8];
115   csum = csum << ((8 - ((SPX_WOTS_LEN2 * SPX_WOTS_LOG_W)) % 8) % 8);
116   spx_uint64_to_len_bytes(csum_bytes, sizeof(csum_bytes), csum);
117 
118   // Write the base w representation of csum to the end of the message.
119   spx_base_b(base_w_msg + SPX_WOTS_LEN1, SPX_WOTS_LEN2, csum_bytes,
120              /*log2_b=*/SPX_WOTS_LOG_W);
121 
122   // Compute chains
123   uint8_t tmp_sk[SPX_N];
124   uint8_t sk_addr[32];
125   memcpy(sk_addr, addr, sizeof(sk_addr));
126   spx_set_type(sk_addr, SPX_ADDR_TYPE_WOTSPRF);
127   spx_copy_keypair_addr(sk_addr, addr);
128 
129   for (size_t i = 0; i < SPX_WOTS_LEN; ++i) {
130     spx_set_chain_addr(sk_addr, i);
131     spx_thash_prf(tmp_sk, pub_seed, sk_seed, sk_addr);
132     spx_set_chain_addr(addr, i);
133     chain(sig + i * SPX_N, tmp_sk, 0, base_w_msg[i], pub_seed, addr);
134   }
135 }
136