xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/crypto/big_num.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/crypto/big_num.h"
17 
18 #include <cmath>
19 #include <string>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/string_view.h"
24 #include "private_join_and_compute/crypto/context.h"
25 #include "private_join_and_compute/crypto/openssl.inc"
26 #include "private_join_and_compute/util/status.inc"
27 
28 namespace private_join_and_compute {
29 
30 namespace {
31 
32 // Utility class for decimal string conversion.
33 class BnString {
34  public:
BnString(char * bn_char)35   explicit BnString(char* bn_char) : bn_char_(bn_char) {}
36 
~BnString()37   ~BnString() { OPENSSL_free(bn_char_); }
38 
ToString()39   std::string ToString() { return std::string(bn_char_); }
40 
41  private:
42   char* const bn_char_;
43 };
44 
45 }  // namespace
46 
BigNum(const BigNum & other)47 BigNum::BigNum(const BigNum& other)
48     : bn_(BignumPtr(BN_dup(other.bn_.get()))), bn_ctx_(other.bn_ctx_) {}
49 
operator =(const BigNum & other)50 BigNum& BigNum::operator=(const BigNum& other) {
51   BIGNUM* temp = BN_dup(other.bn_.get());
52   CHECK_NE(temp, nullptr);
53   bn_ = BignumPtr(temp);
54   bn_ctx_ = other.bn_ctx_;
55   return *this;
56 }
57 
BigNum(BigNum && other)58 BigNum::BigNum(BigNum&& other)
59     : bn_(std::move(other.bn_)), bn_ctx_(other.bn_ctx_) {}
60 
operator =(BigNum && other)61 BigNum& BigNum::operator=(BigNum&& other) {
62   bn_ = std::move(other.bn_);
63   bn_ctx_ = other.bn_ctx_;
64   return *this;
65 }
66 
BigNum(BN_CTX * bn_ctx,uint64_t number)67 BigNum::BigNum(BN_CTX* bn_ctx, uint64_t number) : BigNum::BigNum(bn_ctx) {
68   CRYPTO_CHECK(BN_set_u64(bn_.get(), number));
69 }
70 
BigNum(BN_CTX * bn_ctx,absl::string_view bytes)71 BigNum::BigNum(BN_CTX* bn_ctx, absl::string_view bytes)
72     : BigNum::BigNum(bn_ctx) {
73   CRYPTO_CHECK(nullptr !=
74                BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.data()),
75                          bytes.size(), bn_.get()));
76 }
77 
BigNum(BN_CTX * bn_ctx,const unsigned char * bytes,int length)78 BigNum::BigNum(BN_CTX* bn_ctx, const unsigned char* bytes, int length)
79     : BigNum::BigNum(bn_ctx) {
80   CRYPTO_CHECK(nullptr != BN_bin2bn(bytes, length, bn_.get()));
81 }
82 
BigNum(BN_CTX * bn_ctx)83 BigNum::BigNum(BN_CTX* bn_ctx) {
84   BIGNUM* temp = BN_new();
85   CHECK_NE(temp, nullptr);
86   bn_ = BignumPtr(temp);
87   bn_ctx_ = bn_ctx;
88 }
89 
BigNum(BN_CTX * bn_ctx,BignumPtr bn)90 BigNum::BigNum(BN_CTX* bn_ctx, BignumPtr bn) {
91   bn_ = std::move(bn);
92   bn_ctx_ = bn_ctx;
93 }
94 
GetConstBignumPtr() const95 const BIGNUM* BigNum::GetConstBignumPtr() const { return bn_.get(); }
96 
ToBytes() const97 std::string BigNum::ToBytes() const {
98   CHECK(IsNonNegative()) << "Cannot serialize a negative BigNum.";
99   int length = BN_num_bytes(bn_.get());
100 
101   std::string bytes(length, 0);
102   BN_bn2bin(bn_.get(), reinterpret_cast<unsigned char*>(bytes.data()));
103   return bytes;
104 }
105 
ToIntValue() const106 StatusOr<uint64_t> BigNum::ToIntValue() const {
107   uint64_t val;
108   if (!BN_get_u64(bn_.get(), &val)) {
109     return InvalidArgumentError("BigNum has more than 64 bits.");
110   }
111   return val;
112 }
113 
ToDecimalString() const114 std::string BigNum::ToDecimalString() const {
115   return BnString(BN_bn2dec(GetConstBignumPtr())).ToString();
116 }
117 
BitLength() const118 int BigNum::BitLength() const { return BN_num_bits(bn_.get()); }
119 
IsPrime(double prime_error_probability) const120 bool BigNum::IsPrime(double prime_error_probability) const {
121   int rounds = static_cast<int>(ceil(-log(prime_error_probability) / log(4)));
122   return (1 == BN_is_prime_ex(bn_.get(), rounds, bn_ctx_, nullptr));
123 }
124 
IsSafePrime(double prime_error_probability) const125 bool BigNum::IsSafePrime(double prime_error_probability) const {
126   return IsPrime(prime_error_probability) &&
127          ((*this - BigNum(bn_ctx_, 1)) / BigNum(bn_ctx_, 2))
128              .IsPrime(prime_error_probability);
129 }
130 
IsZero() const131 bool BigNum::IsZero() const { return BN_is_zero(bn_.get()); }
132 
IsOne() const133 bool BigNum::IsOne() const { return BN_is_one(bn_.get()); }
134 
IsNonNegative() const135 bool BigNum::IsNonNegative() const { return !BN_is_negative(bn_.get()); }
136 
GetLastNBits(int n) const137 BigNum BigNum::GetLastNBits(int n) const {
138   BigNum r = *this;
139   // Returns 0 on error (if r is already shorter than n bits), but the return
140   // value in that case should be the original value so there is no need to have
141   // error checking here.
142   BN_mask_bits(r.bn_.get(), n);
143   return r;
144 }
145 
IsBitSet(int n) const146 bool BigNum::IsBitSet(int n) const { return BN_is_bit_set(bn_.get(), n); }
147 
148 // Returns a BigNum whose value is (- *this).
149 // Causes a check failure if the operation fails.
Neg() const150 BigNum BigNum::Neg() const {
151   BigNum r = *this;
152   BN_set_negative(r.bn_.get(), !BN_is_negative(r.bn_.get()));
153   return r;
154 }
155 
Add(const BigNum & val) const156 BigNum BigNum::Add(const BigNum& val) const {
157   BigNum r(bn_ctx_);
158   CRYPTO_CHECK(1 == BN_add(r.bn_.get(), bn_.get(), val.bn_.get()));
159   return r;
160 }
161 
Mul(const BigNum & val) const162 BigNum BigNum::Mul(const BigNum& val) const {
163   BigNum r(bn_ctx_);
164   CRYPTO_CHECK(1 == BN_mul(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_));
165   return r;
166 }
167 
Sub(const BigNum & val) const168 BigNum BigNum::Sub(const BigNum& val) const {
169   BigNum r(bn_ctx_);
170   CRYPTO_CHECK(1 == BN_sub(r.bn_.get(), bn_.get(), val.bn_.get()));
171   return r;
172 }
173 
Div(const BigNum & val) const174 BigNum BigNum::Div(const BigNum& val) const {
175   BigNum r(bn_ctx_);
176   BIGNUM* temp = BN_new();
177   CHECK_NE(temp, nullptr);
178   BignumPtr rem(temp);
179   CRYPTO_CHECK(
180       1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(), bn_ctx_));
181   CHECK(BN_is_zero(rem.get())) << "Use DivAndTruncate() instead of Div() if "
182                                   "you want truncated division.";
183   return r;
184 }
185 
DivAndTruncate(const BigNum & val) const186 BigNum BigNum::DivAndTruncate(const BigNum& val) const {
187   BigNum r(bn_ctx_);
188   BIGNUM* temp = BN_new();
189   CHECK_NE(temp, nullptr);
190   BignumPtr rem(temp);
191   CRYPTO_CHECK(
192       1 == BN_div(r.bn_.get(), rem.get(), bn_.get(), val.bn_.get(), bn_ctx_));
193   return r;
194 }
195 
CompareTo(const BigNum & val) const196 int BigNum::CompareTo(const BigNum& val) const {
197   return BN_cmp(bn_.get(), val.bn_.get());
198 }
199 
Exp(const BigNum & exponent) const200 BigNum BigNum::Exp(const BigNum& exponent) const {
201   BigNum r(bn_ctx_);
202   CRYPTO_CHECK(1 ==
203                BN_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(), bn_ctx_));
204   return r;
205 }
206 
Mod(const BigNum & m) const207 BigNum BigNum::Mod(const BigNum& m) const {
208   BigNum r(bn_ctx_);
209   CRYPTO_CHECK(1 == BN_nnmod(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
210   return r;
211 }
212 
ModAdd(const BigNum & val,const BigNum & m) const213 BigNum BigNum::ModAdd(const BigNum& val, const BigNum& m) const {
214   BigNum r(bn_ctx_);
215   CRYPTO_CHECK(1 == BN_mod_add(r.bn_.get(), bn_.get(), val.bn_.get(),
216                                m.bn_.get(), bn_ctx_));
217   return r;
218 }
219 
ModSub(const BigNum & val,const BigNum & m) const220 BigNum BigNum::ModSub(const BigNum& val, const BigNum& m) const {
221   BigNum r(bn_ctx_);
222   CRYPTO_CHECK(1 == BN_mod_sub(r.bn_.get(), bn_.get(), val.bn_.get(),
223                                m.bn_.get(), bn_ctx_));
224   return r;
225 }
226 
ModMul(const BigNum & val,const BigNum & m) const227 BigNum BigNum::ModMul(const BigNum& val, const BigNum& m) const {
228   BigNum r(bn_ctx_);
229   CRYPTO_CHECK(1 == BN_mod_mul(r.bn_.get(), bn_.get(), val.bn_.get(),
230                                m.bn_.get(), bn_ctx_));
231   return r;
232 }
233 
ModExp(const BigNum & exponent,const BigNum & m) const234 BigNum BigNum::ModExp(const BigNum& exponent, const BigNum& m) const {
235   CHECK(exponent.IsNonNegative()) << "Cannot use a negative exponent in BigNum "
236                                      "ModExp.";
237   BigNum r(bn_ctx_);
238   CRYPTO_CHECK(1 == BN_mod_exp(r.bn_.get(), bn_.get(), exponent.bn_.get(),
239                                m.bn_.get(), bn_ctx_));
240   return r;
241 }
242 
ModSqr(const BigNum & m) const243 BigNum BigNum::ModSqr(const BigNum& m) const {
244   BigNum r(bn_ctx_);
245   CRYPTO_CHECK(1 == BN_mod_sqr(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
246   return r;
247 }
248 
ModInverse(const BigNum & m) const249 StatusOr<BigNum> BigNum::ModInverse(const BigNum& m) const {
250   BigNum r(bn_ctx_);
251   if (nullptr == BN_mod_inverse(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_)) {
252     return InvalidArgumentError(
253         absl::StrCat("BigNum::ModInverse failed: ", OpenSSLErrorString()));
254   }
255   return r;
256 }
257 
ModSqrt(const BigNum & m) const258 BigNum BigNum::ModSqrt(const BigNum& m) const {
259   BigNum r(bn_ctx_);
260   CRYPTO_CHECK(nullptr !=
261                BN_mod_sqrt(r.bn_.get(), bn_.get(), m.bn_.get(), bn_ctx_));
262   return r;
263 }
264 
ModNegate(const BigNum & m) const265 BigNum BigNum::ModNegate(const BigNum& m) const {
266   if (IsZero()) {
267     return *this;
268   }
269   return m - Mod(m);
270 }
271 
Lshift(int n) const272 BigNum BigNum::Lshift(int n) const {
273   BigNum r(bn_ctx_);
274   CRYPTO_CHECK(1 == BN_lshift(r.bn_.get(), bn_.get(), n));
275   return r;
276 }
277 
Rshift(int n) const278 BigNum BigNum::Rshift(int n) const {
279   BigNum r(bn_ctx_);
280   CRYPTO_CHECK(1 == BN_rshift(r.bn_.get(), bn_.get(), n));
281   return r;
282 }
283 
Gcd(const BigNum & val) const284 BigNum BigNum::Gcd(const BigNum& val) const {
285   BigNum r(bn_ctx_);
286   CRYPTO_CHECK(1 == BN_gcd(r.bn_.get(), bn_.get(), val.bn_.get(), bn_ctx_));
287   return r;
288 }
289 
290 }  // namespace private_join_and_compute
291