1 /* Copyright (c) 2023, Google LLC
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15 #include <openssl/base.h>
16
17 #include <stdint.h>
18 #include <stdio.h>
19 #include <string.h>
20
21 #include "./address.h"
22 #include "./params.h"
23 #include "./spx_util.h"
24 #include "./thash.h"
25 #include "./wots.h"
26
27 // Chaining function used in WOTS+.
chain(uint8_t * output,const uint8_t * input,uint32_t start,uint32_t steps,const uint8_t * pub_seed,uint8_t addr[32])28 static void chain(uint8_t *output, const uint8_t *input, uint32_t start,
29 uint32_t steps, const uint8_t *pub_seed, uint8_t addr[32]) {
30 memcpy(output, input, SPX_N);
31
32 for (size_t i = start; i < (start + steps) && i < SPX_WOTS_W; ++i) {
33 spx_set_hash_addr(addr, i);
34 spx_thash_f(output, output, pub_seed, addr);
35 }
36 }
37
spx_wots_pk_from_sig(uint8_t * pk,const uint8_t * sig,const uint8_t * msg,const uint8_t pub_seed[SPX_N],uint8_t addr[32])38 void spx_wots_pk_from_sig(uint8_t *pk, const uint8_t *sig, const uint8_t *msg,
39 const uint8_t pub_seed[SPX_N], uint8_t addr[32]) {
40 uint8_t tmp[SPX_WOTS_BYTES];
41 uint8_t wots_pk_addr[32];
42 memcpy(wots_pk_addr, addr, sizeof(wots_pk_addr));
43
44 // Convert message to base w
45 uint32_t base_w_msg[SPX_WOTS_LEN];
46 spx_base_b(base_w_msg, SPX_WOTS_LEN1, msg, /*log2_b=*/SPX_WOTS_LOG_W);
47
48 // Compute checksum
49 uint64_t csum = 0;
50 for (size_t i = 0; i < SPX_WOTS_LEN1; ++i) {
51 csum += SPX_WOTS_W - 1 - base_w_msg[i];
52 }
53
54 // Convert csum to base w as in Algorithm 7, Line 9
55 uint8_t csum_bytes[(SPX_WOTS_LEN2 * SPX_WOTS_LOG_W + 7) / 8];
56 csum = csum << ((8 - ((SPX_WOTS_LEN2 * SPX_WOTS_LOG_W)) % 8) % 8);
57 spx_uint64_to_len_bytes(csum_bytes, sizeof(csum_bytes), csum);
58
59 // Write the base w representation of csum to the end of the message.
60 spx_base_b(base_w_msg + SPX_WOTS_LEN1, SPX_WOTS_LEN2, csum_bytes,
61 /*log2_b=*/SPX_WOTS_LOG_W);
62
63 // Compute chains
64 for (size_t i = 0; i < SPX_WOTS_LEN; ++i) {
65 spx_set_chain_addr(addr, i);
66 chain(tmp + i * SPX_N, sig + i * SPX_N, base_w_msg[i],
67 SPX_WOTS_W - 1 - base_w_msg[i], pub_seed, addr);
68 }
69
70 // Compress pk
71 spx_set_type(wots_pk_addr, SPX_ADDR_TYPE_WOTSPK);
72 spx_copy_keypair_addr(wots_pk_addr, addr);
73 spx_thash_tl(pk, tmp, pub_seed, wots_pk_addr);
74 }
75
spx_wots_pk_gen(uint8_t * pk,const uint8_t sk_seed[SPX_N],const uint8_t pub_seed[SPX_N],uint8_t addr[32])76 void spx_wots_pk_gen(uint8_t *pk, const uint8_t sk_seed[SPX_N],
77 const uint8_t pub_seed[SPX_N], uint8_t addr[32]) {
78 uint8_t tmp[SPX_WOTS_BYTES];
79 uint8_t tmp_sk[SPX_N];
80 uint8_t wots_pk_addr[32], sk_addr[32];
81 memcpy(wots_pk_addr, addr, sizeof(wots_pk_addr));
82 memcpy(sk_addr, addr, sizeof(sk_addr));
83
84 spx_set_type(sk_addr, SPX_ADDR_TYPE_WOTSPRF);
85 spx_copy_keypair_addr(sk_addr, addr);
86
87 for (size_t i = 0; i < SPX_WOTS_LEN; ++i) {
88 spx_set_chain_addr(sk_addr, i);
89 spx_thash_prf(tmp_sk, pub_seed, sk_seed, sk_addr);
90 spx_set_chain_addr(addr, i);
91 chain(tmp + i * SPX_N, tmp_sk, 0, SPX_WOTS_W - 1, pub_seed, addr);
92 }
93
94 // Compress pk
95 spx_set_type(wots_pk_addr, SPX_ADDR_TYPE_WOTSPK);
96 spx_copy_keypair_addr(wots_pk_addr, addr);
97 spx_thash_tl(pk, tmp, pub_seed, wots_pk_addr);
98 }
99
spx_wots_sign(uint8_t * sig,const uint8_t msg[SPX_N],const uint8_t sk_seed[SPX_N],const uint8_t pub_seed[SPX_N],uint8_t addr[32])100 void spx_wots_sign(uint8_t *sig, const uint8_t msg[SPX_N],
101 const uint8_t sk_seed[SPX_N], const uint8_t pub_seed[SPX_N],
102 uint8_t addr[32]) {
103 // Convert message to base w
104 uint32_t base_w_msg[SPX_WOTS_LEN];
105 spx_base_b(base_w_msg, SPX_WOTS_LEN1, msg, /*log2_b=*/SPX_WOTS_LOG_W);
106
107 // Compute checksum
108 uint64_t csum = 0;
109 for (size_t i = 0; i < SPX_WOTS_LEN1; ++i) {
110 csum += SPX_WOTS_W - 1 - base_w_msg[i];
111 }
112
113 // Convert csum to base w as in Algorithm 6, Line 9
114 uint8_t csum_bytes[(SPX_WOTS_LEN2 * SPX_WOTS_LOG_W + 7) / 8];
115 csum = csum << ((8 - ((SPX_WOTS_LEN2 * SPX_WOTS_LOG_W)) % 8) % 8);
116 spx_uint64_to_len_bytes(csum_bytes, sizeof(csum_bytes), csum);
117
118 // Write the base w representation of csum to the end of the message.
119 spx_base_b(base_w_msg + SPX_WOTS_LEN1, SPX_WOTS_LEN2, csum_bytes,
120 /*log2_b=*/SPX_WOTS_LOG_W);
121
122 // Compute chains
123 uint8_t tmp_sk[SPX_N];
124 uint8_t sk_addr[32];
125 memcpy(sk_addr, addr, sizeof(sk_addr));
126 spx_set_type(sk_addr, SPX_ADDR_TYPE_WOTSPRF);
127 spx_copy_keypair_addr(sk_addr, addr);
128
129 for (size_t i = 0; i < SPX_WOTS_LEN; ++i) {
130 spx_set_chain_addr(sk_addr, i);
131 spx_thash_prf(tmp_sk, pub_seed, sk_seed, sk_addr);
132 spx_set_chain_addr(addr, i);
133 chain(sig + i * SPX_N, tmp_sk, 0, base_w_msg[i], pub_seed, addr);
134 }
135 }
136