1 /*
2 * Copyright 2019 Google LLC.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * https://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "private_join_and_compute/crypto/big_num.h"

#include <cmath>
#include <limits>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "private_join_and_compute/crypto/context.h"
#include "private_join_and_compute/crypto/openssl.inc"
#include "private_join_and_compute/util/status.inc"
27
28 namespace private_join_and_compute {
29
30 namespace {
31
32 // Utility class for decimal string conversion.
33 class BnString {
34 public:
BnString(char * bn_char)35 explicit BnString(char* bn_char) : bn_char_(bn_char) {}
36
~BnString()37 ~BnString() { OPENSSL_free(bn_char_); }
38
ToString()39 std::string ToString() { return std::string(bn_char_); }
40
41 private:
42 char* const bn_char_;
43 };
44
45 } // namespace
46
// Deep-copies `other`'s value into a newly allocated BIGNUM; the BN_CTX
// pointer is shared with `other`, not duplicated.
BigNum::BigNum(const BigNum& other)
    : bn_(BignumPtr(BN_dup(other.bn_.get()))), bn_ctx_(other.bn_ctx_) {
  // BN_dup returns nullptr on failure. Fail loudly here, as the copy
  // assignment operator does, instead of carrying a null BIGNUM forward.
  CHECK_NE(bn_.get(), nullptr);
}
49
// Copy assignment: replaces this value with a deep copy of `other` and adopts
// its BN_CTX pointer. The duplicate is made before the old value is released,
// which keeps self-assignment safe.
BigNum& BigNum::operator=(const BigNum& other) {
  BIGNUM* temp = BN_dup(other.bn_.get());
  CHECK_NE(temp, nullptr);
  bn_ctx_ = other.bn_ctx_;
  bn_ = BignumPtr(temp);
  return *this;
}
57
BigNum(BigNum && other)58 BigNum::BigNum(BigNum&& other)
59 : bn_(std::move(other.bn_)), bn_ctx_(other.bn_ctx_) {}
60
operator =(BigNum && other)61 BigNum& BigNum::operator=(BigNum&& other) {
62 bn_ = std::move(other.bn_);
63 bn_ctx_ = other.bn_ctx_;
64 return *this;
65 }
66
BigNum(BN_CTX * bn_ctx,uint64_t number)67 BigNum::BigNum(BN_CTX* bn_ctx, uint64_t number) : BigNum::BigNum(bn_ctx) {
68 CRYPTO_CHECK(BN_set_u64(bn_.get(), number));
69 }
70
BigNum(BN_CTX * bn_ctx,absl::string_view bytes)71 BigNum::BigNum(BN_CTX* bn_ctx, absl::string_view bytes)
72 : BigNum::BigNum(bn_ctx) {
73 CRYPTO_CHECK(nullptr !=
74 BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()),
75 bytes.size(), bn_.get()));
76 }
77
BigNum(BN_CTX * bn_ctx,const unsigned char * bytes,int length)78 BigNum::BigNum(BN_CTX* bn_ctx, const unsigned char* bytes, int length)
79 : BigNum::BigNum(bn_ctx) {
80 CRYPTO_CHECK(nullptr != BN_bin2bn(bytes, length, bn_.get()));
81 }
82
BigNum(BN_CTX * bn_ctx)83 BigNum::BigNum(BN_CTX* bn_ctx) {
84 BIGNUM* temp = BN_new();
85 CHECK_NE(temp, nullptr);
86 bn_ = BignumPtr(temp);
87 bn_ctx_ = bn_ctx;
88 }
89
BigNum(BN_CTX * bn_ctx,BignumPtr bn)90 BigNum::BigNum(BN_CTX* bn_ctx, BignumPtr bn) {
91 bn_ = std::move(bn);
92 bn_ctx_ = bn_ctx;
93 }
94
GetConstBignumPtr() const95 const BIGNUM* BigNum::GetConstBignumPtr() const { return bn_.get(); }
96
ToBytes() const97 std::string BigNum::ToBytes() const {
98 CHECK(IsNonNegative()) << "Cannot serialize a negative BigNum.";
99 int length = BN_num_bytes(bn_.get());
100
101 std::string bytes(length, 0);
102 BN_bn2bin(bn_.get(), reinterpret_cast<unsigned char*>(bytes.data()));
103 return bytes;
104 }
105
ToIntValue() const106 StatusOr<uint64_t> BigNum::ToIntValue() const {
107 uint64_t val;
108 if (!BN_get_u64(bn_.get(), &val)) {
109 return InvalidArgumentError("BigNum has more than 64 bits.");
110 }
111 return val;
112 }
113
ToDecimalString() const114 std::string BigNum::ToDecimalString() const {
115 return BnString(BN_bn2dec(GetConstBignumPtr())).ToString();
116 }
117
BitLength() const118 int BigNum::BitLength() const { return BN_num_bits(bn_.get()); }
119
IsPrime(double prime_error_probability) const120 bool BigNum::IsPrime(double prime_error_probability) const {
121 int rounds = static_cast<int>(ceil(-log(prime_error_probability) / log(4)));
122 return (1 == BN_is_prime_ex(bn_.get(), rounds, bn_ctx_, nullptr));
123 }
124
IsSafePrime(double prime_error_probability) const125 bool BigNum::IsSafePrime(double prime_error_probability) const {
126 return IsPrime(prime_error_probability) &&
127 ((*this - BigNum(bn_ctx_, 1)) / BigNum(bn_ctx_, 2))
128 .IsPrime(prime_error_probability);
129 }
130
IsZero() const131 bool BigNum::IsZero() const { return BN_is_zero(bn_.get()); }
132
IsOne() const133 bool BigNum::IsOne() const { return BN_is_one(bn_.get()); }
134
IsNonNegative() const135 bool BigNum::IsNonNegative() const { return !BN_is_negative(bn_.get()); }
136
GetLastNBits(int n) const137 BigNum BigNum::GetLastNBits(int n) const {
138 BigNum r = *this;
139 // Returns 0 on error (if r is already shorter than n bits), but the return
140 // value in that case should be the original value so there is no need to have
141 // error checking here.
142 BN_mask_bits(r.bn_.get(), n);
143 return r;
144 }
145
IsBitSet(int n) const146 bool BigNum::IsBitSet(int n) const { return BN_is_bit_set(bn_.get(), n); }
147
148 // Returns a BigNum whose value is (- *this).
149 // Causes a check failure if the operation fails.
Neg() const150 BigNum BigNum::Neg() const {
151 BigNum r = *this;
152 BN_set_negative(r.bn_.get(), !BN_is_negative(r.bn_.get()));
153 return r;
154 }
155
// Returns *this + val. An OpenSSL error fails the CRYPTO_CHECK.
BigNum BigNum::Add(const BigNum& val) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 ==
               BN_add(r.bn_.get(), bn_.get(), val.bn_.get()));
  return r;
}
161
// Returns *this * val, using this object's BN_CTX as scratch space. An
// OpenSSL error fails the CRYPTO_CHECK.
BigNum BigNum::Mul(const BigNum& val) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 ==
               BN_mul(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_));
  return r;
}
167
// Returns *this - val. An OpenSSL error fails the CRYPTO_CHECK.
BigNum BigNum::Sub(const BigNum& val) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 ==
               BN_sub(r.bn_.get(), bn_.get(), val.bn_.get()));
  return r;
}
173
// Exact division: returns *this / val. It is a CHECK failure if val does not
// divide *this evenly; DivAndTruncate() is the truncating variant.
BigNum BigNum::Div(const BigNum& val) const {
  auto r = BigNum(bn_ctx_);
  // Scratch BIGNUM that receives the remainder so exactness can be checked.
  BIGNUM* temp = BN_new();
  CHECK_NE(temp, nullptr);
  BignumPtr rem(temp);
  CRYPTO_CHECK(1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(),
                           bn_ctx_));
  CHECK(BN_is_zero(rem.get()))
      << "Use DivAndTruncate() instead of Div() if you want truncated division.";
  return r;
}
185
// Truncating division: returns the quotient computed by BN_div and discards
// the remainder. An OpenSSL error (e.g. val == 0) fails the CRYPTO_CHECK.
BigNum BigNum::DivAndTruncate(const BigNum& val) const {
  BigNum r(bn_ctx_);
  // BN_div accepts a null remainder pointer when only the quotient is wanted,
  // so there is no need to allocate a scratch BIGNUM that is never read.
  CRYPTO_CHECK(
      1 == BN_div(r.bn_.get(), nullptr, bn_.get(), val.bn_.get(), bn_ctx_));
  return r;
}
195
CompareTo(const BigNum & val) const196 int BigNum::CompareTo(const BigNum& val) const {
197 return BN_cmp(bn_.get(), val.bn_.get());
198 }
199
// Returns (*this)^exponent, computed without a modulus. An OpenSSL error fails
// the CRYPTO_CHECK.
BigNum BigNum::Exp(const BigNum& exponent) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 == BN_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(),
                           bn_ctx_));
  return r;
}
206
// Returns the non-negative remainder of *this modulo m (BN_nnmod). An OpenSSL
// error fails the CRYPTO_CHECK.
BigNum BigNum::Mod(const BigNum& m) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 ==
               BN_nnmod(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
  return r;
}
212
// Returns (*this + val) mod m via BN_mod_add. An OpenSSL error fails the
// CRYPTO_CHECK.
BigNum BigNum::ModAdd(const BigNum& val, const BigNum& m) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 == BN_mod_add(r.bn_.get(), bn_.get(), val.bn_.get(),
                               m.bn_.get(), bn_ctx_));
  return r;
}
219
// Returns (*this - val) mod m via BN_mod_sub. An OpenSSL error fails the
// CRYPTO_CHECK.
BigNum BigNum::ModSub(const BigNum& val, const BigNum& m) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 == BN_mod_sub(r.bn_.get(), bn_.get(), val.bn_.get(),
                               m.bn_.get(), bn_ctx_));
  return r;
}
226
// Returns (*this * val) mod m via BN_mod_mul. An OpenSSL error fails the
// CRYPTO_CHECK.
BigNum BigNum::ModMul(const BigNum& val, const BigNum& m) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 == BN_mod_mul(r.bn_.get(), bn_.get(), val.bn_.get(),
                               m.bn_.get(), bn_ctx_));
  return r;
}
233
// Returns (*this)^exponent mod m via BN_mod_exp.
// Fails a CHECK if the exponent is negative, and the CRYPTO_CHECK on any
// OpenSSL error.
BigNum BigNum::ModExp(const BigNum& exponent, const BigNum& m) const {
  CHECK(exponent.IsNonNegative())
      << "Cannot use a negative exponent in BigNum ModExp.";
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 == BN_mod_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(),
                               m.bn_.get(), bn_ctx_));
  return r;
}
242
// Returns (*this)^2 mod m via BN_mod_sqr. An OpenSSL error fails the
// CRYPTO_CHECK.
BigNum BigNum::ModSqr(const BigNum& m) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 ==
               BN_mod_sqr(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
  return r;
}
248
ModInverse(const BigNum & m) const249 StatusOr<BigNum> BigNum::ModInverse(const BigNum& m) const {
250 BigNum r(bn_ctx_);
251 if (nullptr == BN_mod_inverse(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)) {
252 return InvalidArgumentError(
253 absl::StrCat("BigNum::ModInverse failed: ", OpenSSLErrorString()));
254 }
255 return r;
256 }
257
// Returns a square root of *this modulo m via BN_mod_sqrt. The library
// expects m to be prime — confirm at call sites. Any failure, including a
// non-residue input, fails the CRYPTO_CHECK.
BigNum BigNum::ModSqrt(const BigNum& m) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(nullptr != BN_mod_sqrt(r.bn_.get(), bn_.get(), m.bn_.get(),
                                      bn_ctx_));
  return r;
}
264
// Returns (-*this) mod m, reduced into the same range as Mod(m).
BigNum BigNum::ModNegate(const BigNum& m) const {
  // Zero negates to itself; returning early also avoids reducing by m.
  if (IsZero()) {
    return *this;
  }
  // Reduce first: for a non-zero multiple of m the reduced value is zero, and
  // its negation must be zero too. The old `m - Mod(m)` returned m.
  BigNum reduced = Mod(m);
  if (reduced.IsZero()) {
    return reduced;
  }
  return m - reduced;
}
271
Lshift(int n) const272 BigNum BigNum::Lshift(int n) const {
273 BigNum r(bn_ctx_);
274 CRYPTO_CHECK(1 == BN_lshift(r.bn_.get(), bn_.get(), n));
275 return r;
276 }
277
Rshift(int n) const278 BigNum BigNum::Rshift(int n) const {
279 BigNum r(bn_ctx_);
280 CRYPTO_CHECK(1 == BN_rshift(r.bn_.get(), bn_.get(), n));
281 return r;
282 }
283
// Returns gcd(*this, val) via BN_gcd. An OpenSSL error fails the CRYPTO_CHECK.
BigNum BigNum::Gcd(const BigNum& val) const {
  auto r = BigNum(bn_ctx_);
  CRYPTO_CHECK(1 ==
               BN_gcd(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_));
  return r;
}
289
290 } // namespace private_join_and_compute
291