xref: /aosp_15_r20/external/private-join-and-compute/private_join_and_compute/crypto/paillier.cc (revision a6aa18fbfbf9cb5cd47356a9d1b057768998488c)
1 /*
2  * Copyright 2019 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "private_join_and_compute/crypto/paillier.h"
17 
18 #include <stddef.h>
19 
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/node_hash_map.h"
25 #include "absl/log/check.h"
26 #include "absl/log/log.h"
27 #include "private_join_and_compute/crypto/big_num.h"
28 #include "private_join_and_compute/crypto/context.h"
29 #include "private_join_and_compute/crypto/fixed_base_exp.h"
30 #include "private_join_and_compute/crypto/two_modulus_crt.h"
31 #include "private_join_and_compute/util/status.inc"
32 
33 namespace private_join_and_compute {
34 
namespace {
// Upper bound on how many consecutive candidates (2, 3, 4, ...) are examined
// when searching for a generator modulo a safe prime.
constexpr int32_t kGeneratorTryCount{1000};
}  // namespace
40 
41 // A class representing a table of BigNums.
42 // The column length of the table is fixed and given in the constructor.
43 // Example:
44 //   // Given BigNum a;
45 //   BigNumTable table(5);
46 //   table.Insert(2, 3, a);
47 //   BigNum b = table.Get(2, 3)  // returns the same copy of BigNum a each time
48 //                               // Get is called with the same parameters.
49 //
50 // Note that while a two-dimensional vector can be used in place of this class,
51 // this is more versatile in the case of partially filled tables.
52 class BigNumTable {
53  public:
54   // Creates a BigNumTable with a fixed column length.
BigNumTable(size_t column_length)55   explicit BigNumTable(size_t column_length)
56       : column_length_(column_length), table_() {}
57 
58   // Inserts a copy of num into x, y cell of the table.
Insert(int x,int y,const BigNum & num)59   void Insert(int x, int y, const BigNum& num) {
60     CHECK_LT(y, column_length_);
61     table_.insert(std::make_pair(x * column_length_ + y, num));
62   }
63 
64   // Returns a reference to the BigNum at x, y cell.
65   // Note that this object must outlive the scope of whoever called this
66   // function so that the returned reference stays valid.
Get(int x,int y) const67   const BigNum& Get(int x, int y) const {
68     CHECK_LT(y, column_length_);
69     auto iter = table_.find(x * column_length_ + y);
70     if (iter == table_.end()) {
71       LOG(FATAL) << "The element at x = " << x << " and y = " << y
72                  << " does not exist";
73     }
74     return iter->second;
75   }
76 
77  private:
78   const size_t column_length_;
79   absl::node_hash_map<int, BigNum> table_;
80 };
81 
82 namespace {
83 
84 // Returns a BigNum, g, that is a generator for the Zp*.
GetGeneratorForSafePrime(Context * ctx,const BigNum & p)85 BigNum GetGeneratorForSafePrime(Context* ctx, const BigNum& p) {
86   CHECK(p.IsSafePrime());
87   BigNum q = (p - ctx->One()) / ctx->Two();
88   BigNum g = ctx->CreateBigNum(2);
89   for (int32_t i = 0; i < kGeneratorTryCount; i++) {
90     if (g.ModSqr(p).IsOne() || g.ModExp(q, p).IsOne()) {
91       g = g + ctx->One();
92     } else {
93       return g;
94     }
95   }
96   // Just in case IsSafePrime is not correct.
97   LOG(FATAL) << "Either try_count is insufficient or p is not a safe prime."
98              << " generator_try_count: " << kGeneratorTryCount;
99 }
100 
101 // Returns a BigNum, g, that is a generator for Zn*, where n is the product
102 // of 2 safe primes.
GetGeneratorForSafeModulus(Context * ctx,const BigNum & n)103 BigNum GetGeneratorForSafeModulus(Context* ctx, const BigNum& n) {
104   // As explained in Damgard-Jurik-Nielsen, if n is the product of safe primes,
105   // it is sufficient to choose a random number x in Z*n and return
106   // g = -(x^2) mod n
107   BigNum x = ctx->RelativelyPrimeRandomLessThan(n);
108   return n - x.ModSqr(n);
109 }
110 
111 // Returns a BigNum, g, that is a generator for Zp^t* for any t > 1.
GetGeneratorOfPrimePowersFromSafePrime(Context * ctx,const BigNum & p)112 BigNum GetGeneratorOfPrimePowersFromSafePrime(Context* ctx, const BigNum& p) {
113   BigNum g = GetGeneratorForSafePrime(ctx, p);
114   if (g.ModExp(p - ctx->One(), p * p).IsOne()) {
115     return g + p;
116   }
117   return g;
118 }
119 
120 // Returns a vector of num^i for i in [0, s + 1].
GetPowers(Context * ctx,const BigNum & num,int s)121 std::vector<BigNum> GetPowers(Context* ctx, const BigNum& num, int s) {
122   std::vector<BigNum> powers;
123   powers.push_back(ctx->CreateBigNum(1));
124   for (int i = 1; i <= s + 1; i++) {
125     powers.push_back(powers.back().Mul(num));
126   }
127   return powers;
128 }
129 
130 // Returns a vector of (1 / (i!)) * n^i mod n^(s+1) for i in [0, s].
GetPrecomp(Context * ctx,const BigNum & num,const BigNum & modulus,int s)131 std::vector<BigNum> GetPrecomp(Context* ctx, const BigNum& num,
132                                const BigNum& modulus, int s) {
133   std::vector<BigNum> precomp;
134   precomp.push_back(ctx->CreateBigNum(1));
135   for (int i = 1; i <= s; i++) {
136     BigNum i_inv = ctx->CreateBigNum(i).ModInverse(modulus).value();
137     BigNum i_inv_n = i_inv.ModMul(num, modulus);
138     precomp.push_back(precomp.back().ModMul(i_inv_n, modulus));
139   }
140   return precomp;
141 }
142 
143 // Returns a vector of (1 / (k!)) * n^(k - 1) mod p^j for 2 <= k <= j <= s.
144 // Reuses the values from GetPrecomp function output, precomp.
GetDecryptPrecomp(Context * ctx,const std::vector<BigNum> & precomp,const std::vector<BigNum> & powers,int s)145 std::unique_ptr<BigNumTable> GetDecryptPrecomp(
146     Context* ctx, const std::vector<BigNum>& precomp,
147     const std::vector<BigNum>& powers, int s) {
148   // The first index is k and the second one is j from the Theorem 1 algorithm
149   // of Damgaard-Jurik-Nielsen paper.
150   // The table indices are [2, s] in each dimension with the following
151   // structure:
152   //     j
153   //  +-----+
154   //   -----|
155   //    ----|  k
156   //     ---|
157   //      --|
158   //       -+
159   std::unique_ptr<BigNumTable> precomp_table(new BigNumTable(s + 1));
160   for (int k = 2; k <= s; k++) {
161     BigNum k_inverse = ctx->CreateBigNum(k).ModInverse(powers[s]).value();
162     precomp_table->Insert(k, s, k_inverse.ModMul(precomp[k - 1], powers[s]));
163     for (int j = s - 1; j >= k; j--) {
164       precomp_table->Insert(k, j, precomp_table->Get(k, j + 1).Mod(powers[j]));
165     }
166   }
167   return precomp_table;
168 }
169 
170 // Computes (1 + powers[1])^message via binomial expansion (message=m):
171 // 1 + mn + C(m, 2)n^2 + ... + C(m, s)n^s mod n^(s + 1)
ComputeByBinomialExpansion(Context * ctx,const std::vector<BigNum> & precomp,const std::vector<BigNum> & powers,const BigNum & message)172 BigNum ComputeByBinomialExpansion(Context* ctx,
173                                   const std::vector<BigNum>& precomp,
174                                   const std::vector<BigNum>& powers,
175                                   const BigNum& message) {
176   // Refer to Section 4.2 Optimizations of Encryption from the Damgaard-Jurik
177   // cryptosystem paper.
178   BigNum c = ctx->CreateBigNum(1);
179   BigNum tmp = ctx->CreateBigNum(1);
180   const int s = precomp.size() - 1;
181   BigNum reduced_message = message.Mod(powers[s]);
182   for (int j = 1; j <= s; j++) {
183     const BigNum& j_bn = ctx->CreateBigNum(j);
184     if (reduced_message < j_bn) {
185       break;
186     }
187     tmp = tmp.ModMul(reduced_message - j_bn + ctx->One(), powers[s - j + 1]);
188     c = c + tmp.ModMul(precomp[j], powers[s + 1]);
189   }
190   return c;
191 }
192 
193 }  // namespace
194 
195 StatusOr<std::pair<PaillierPublicKey, PaillierPrivateKey>>
GeneratePaillierKeyPair(Context * ctx,int32_t modulus_length,int32_t s)196 GeneratePaillierKeyPair(Context* ctx, int32_t modulus_length, int32_t s) {
197   if (modulus_length / 2 <= 0 || s <= 0) {
198     return InvalidArgumentError(
199         "GeneratePaillierKeyPair: modulus_length/2 and s must each be >0");
200   }
201 
202   BigNum p = ctx->GenerateSafePrime(modulus_length / 2);
203   BigNum q = ctx->GenerateSafePrime(modulus_length / 2);
204   while (p == q) {
205     q = ctx->GenerateSafePrime(modulus_length / 2);
206   }
207   BigNum n = p * q;
208 
209   PaillierPrivateKey private_key;
210   private_key.set_p(p.ToBytes());
211   private_key.set_q(q.ToBytes());
212   private_key.set_s(s);
213 
214   PaillierPublicKey public_key;
215   public_key.set_n(n.ToBytes());
216   public_key.set_s(s);
217 
218   return std::make_pair(std::move(public_key), std::move(private_key));
219 }
220 
221 // A helper class defining Encrypt and Decrypt for only one of the prime parts
222 // of the composite number n. Computing (1+n)^m * g^r mod p^(s+1) where r is in
223 // [1, p) for both p and q and then computing CRT yields a result with the same
224 // randomness as computing (1+n)^m * random^(n^s) mod n^(s+1) whereas the former
225 // is much faster as the modulus length is half the size of n for each step.
226 //
227 // This class is not thread-safe since Context is not thread-safe.
228 // Note that this does *not* take the ownership of Context.
229 class PrimeCrypto {
230  public:
231   // Creates a PrimeCrypto with the given parameter where p and other_prime is
232   // either <p, q> or <q, p>.
PrimeCrypto(Context * ctx,const BigNum & p,const BigNum & other_prime,int s)233   PrimeCrypto(Context* ctx, const BigNum& p, const BigNum& other_prime, int s)
234       : ctx_(ctx),
235         p_(p),
236         p_phi_(p - ctx->One()),
237         n_(p * other_prime),
238         s_(s),
239         powers_(GetPowers(ctx, p, s)),
240         precomp_(GetPrecomp(ctx, n_, powers_[s + 1], s)),
241         lambda_inv_(p_phi_.ModInverse(powers_[s_]).value()),
242         other_prime_inv_(other_prime.ModInverse(powers_[s]).value()),
243         decrypt_precomp_(GetDecryptPrecomp(ctx, precomp_, powers_, s)),
244         g_p_(GetGeneratorOfPrimePowersFromSafePrime(ctx, p)),
245         fbe_(FixedBaseExp::GetFixedBaseExp(
246             ctx, g_p_.ModExp(n_.Exp(ctx->CreateBigNum(s)), powers_[s + 1]),
247             powers_[s + 1])) {}
248 
249   // PrimeCrypto is neither copyable nor movable.
250   PrimeCrypto(const PrimeCrypto&) = delete;
251   PrimeCrypto& operator=(const PrimeCrypto&) = delete;
252 
253   // Computes (1+n)^m * g^r mod p^(s+1) where r is in [1, p).
Encrypt(const BigNum & m) const254   StatusOr<BigNum> Encrypt(const BigNum& m) const {
255     return EncryptWithRand(m, ctx_->GenerateRandBetween(ctx_->One(), p_));
256   }
257 
258   // Encrypts the message similar to other Encrypt method, but uses the input
259   // random value. (The caller has responsibility to ensure the randomness of
260   // the value.)
EncryptWithRand(const BigNum & m,const BigNum & r) const261   StatusOr<BigNum> EncryptWithRand(const BigNum& m, const BigNum& r) const {
262     BigNum c_p = ComputeByBinomialExpansion(ctx_, precomp_, powers_, m);
263     ASSIGN_OR_RETURN(BigNum g_to_r, fbe_->ModExp(r));
264     return c_p.ModMul(g_to_r, powers_[s_ + 1]);
265   }
266 
267   // Decrypts c for this prime part so that computing CRT with the other prime
268   // decryption yields to the original message inside this ciphertext.
Decrypt(const BigNum & c) const269   BigNum Decrypt(const BigNum& c) const {
270     // Theorem 1 algorithm from Damgaard-Jurik-Nielsen paper.
271     // Cancels out the random portion and compute the L function.
272     BigNum l_u = LFunc(c.ModExp(p_phi_, powers_[s_ + 1]));
273     BigNum m_lambda = ctx_->CreateBigNum(0);
274     for (int j = 1; j <= s_; j++) {
275       BigNum t1 = l_u.Mod(powers_[j]);
276       BigNum t2 = m_lambda;
277       for (int k = 2; k <= j; k++) {
278         m_lambda = m_lambda - ctx_->One();
279         t2 = t2.ModMul(m_lambda, powers_[j]);
280         t1 = t1 - t2 * decrypt_precomp_->Get(k, j);
281       }
282       m_lambda = std::move(t1);
283     }
284     return m_lambda.ModMul(lambda_inv_, powers_[s_]);
285   }
286 
287   // Returns p^i from the cache.
GetPToExp(int i) const288   const BigNum& GetPToExp(int i) const { return powers_[i]; }
289 
290  private:
291   friend class PrimeCryptoWithRand;
292   // Paillier L function modified to work on prime parts. Refer to the
293   // subsection "Decryption" under Section 4.2 "Optimizations of Encryption"
294   // from the Damgaard-Jurik cryptosystem paper.
LFunc(const BigNum & c_mod_p_to_s_plus_one) const295   BigNum LFunc(const BigNum& c_mod_p_to_s_plus_one) const {
296     return ((c_mod_p_to_s_plus_one - ctx_->One()) / p_)
297         .ModMul(other_prime_inv_, GetPToExp(s_));
298   }
299 
300   Context* const ctx_;
301   const BigNum p_;
302   const BigNum p_phi_;
303   const BigNum n_;
304   const int s_;
305   const std::vector<BigNum> powers_;
306   const std::vector<BigNum> precomp_;
307   const BigNum lambda_inv_;
308   const BigNum other_prime_inv_;
309   const std::unique_ptr<BigNumTable> decrypt_precomp_;
310   const BigNum g_p_;
311   std::unique_ptr<FixedBaseExp> fbe_;
312 };
313 
314 // Class that wraps a PrimeCrypto, and additionally can return the random number
315 // (used in an encryption) with the ciphertext.
316 class PrimeCryptoWithRand {
317  public:
PrimeCryptoWithRand(PrimeCrypto * prime_crypto)318   explicit PrimeCryptoWithRand(PrimeCrypto* prime_crypto)
319       : ctx_(prime_crypto->ctx_),
320         prime_crypto_(prime_crypto),
321         exp_for_report_(FixedBaseExp::GetFixedBaseExp(
322             ctx_, prime_crypto_->g_p_,
323             prime_crypto_->GetPToExp(prime_crypto_->s_ + 1))) {}
324 
325   // PrimeCryptoWithRand is neither copyable nor movable.
326   PrimeCryptoWithRand(const PrimeCryptoWithRand&) = delete;
327   PrimeCryptoWithRand& operator=(const PrimeCryptoWithRand&) = delete;
328 
329   // Encrypts the message and returns the result the same way as in PrimeCrypto.
Encrypt(const BigNum & m) const330   StatusOr<BigNum> Encrypt(const BigNum& m) const {
331     return prime_crypto_->Encrypt(m);
332   }
333 
334   // Encrypts the message with the input random value the same way as in
335   // PrimeCrypto.
EncryptWithRand(const BigNum & m,const BigNum & r) const336   StatusOr<BigNum> EncryptWithRand(const BigNum& m, const BigNum& r) const {
337     return prime_crypto_->EncryptWithRand(m, r);
338   }
339 
340   // Encrypts the message the same way as in PrimeCrypto, and returns the
341   // random used.
EncryptAndGetRand(const BigNum & m) const342   StatusOr<PaillierEncAndRand> EncryptAndGetRand(const BigNum& m) const {
343     BigNum r = ctx_->GenerateRandBetween(ctx_->One(), prime_crypto_->p_);
344     ASSIGN_OR_RETURN(BigNum ct, EncryptWithRand(m, r));
345     ASSIGN_OR_RETURN(BigNum exp_for_report_to_r, exp_for_report_->ModExp(r));
346     return {{std::move(ct), std::move(exp_for_report_to_r)}};
347   }
348 
349   // Decrypts the ciphertext the same way as in PrimeCrypto.
Decrypt(const BigNum & c) const350   BigNum Decrypt(const BigNum& c) const { return prime_crypto_->Decrypt(c); }
351 
352  private:
353   Context* const ctx_;
354   const PrimeCrypto* const prime_crypto_;
355   std::unique_ptr<FixedBaseExp> exp_for_report_;
356 };
357 
// Value of s used by the constructors that do not take an explicit s.
static constexpr int kDefaultS = 1;
359 
PublicPaillier(Context * ctx,const BigNum & n,int s)360 PublicPaillier::PublicPaillier(Context* ctx, const BigNum& n, int s)
361     : ctx_(ctx),
362       n_(n),
363       s_(s),
364       n_powers_(GetPowers(ctx, n_, s)),
365       modulus_(n_powers_.back()),
366       g_n_fbe_(FixedBaseExp::GetFixedBaseExp(
367           ctx,
368           GetGeneratorForSafeModulus(ctx_, n).ModExp(n_powers_[s], modulus_),
369           modulus_)),
370       precomp_(GetPrecomp(ctx, n_, modulus_, s)) {}
371 
PublicPaillier(Context * ctx,const BigNum & n)372 PublicPaillier::PublicPaillier(Context* ctx, const BigNum& n)
373     : PublicPaillier(ctx, n, kDefaultS) {}
374 
PublicPaillier(Context * ctx,const PaillierPublicKey & public_key_proto)375 PublicPaillier::PublicPaillier(Context* ctx,
376                                const PaillierPublicKey& public_key_proto)
377     : PublicPaillier(ctx, ctx->CreateBigNum(public_key_proto.n()),
378                      public_key_proto.s()) {}
379 
380 PublicPaillier::~PublicPaillier() = default;
381 
// Homomorphic addition: returns the product of the two ciphertexts modulo the
// ciphertext modulus.
BigNum PublicPaillier::Add(const BigNum& ciphertext1,
                           const BigNum& ciphertext2) const {
  BigNum combined = ciphertext1.ModMul(ciphertext2, modulus_);
  return combined;
}
386 
// Homomorphic multiplication by the plaintext scalar m: returns c^m modulo
// the ciphertext modulus.
BigNum PublicPaillier::Multiply(const BigNum& c, const BigNum& m) const {
  BigNum scaled = c.ModExp(m, modulus_);
  return scaled;
}
390 
// Homomorphically multiplies the plaintext inside c by 2^shift_amount, by
// delegating to Multiply with the scalar 1 << shift_amount.
BigNum PublicPaillier::LeftShift(const BigNum& c, int shift_amount) const {
  const BigNum power_of_two = ctx_->One().Lshift(shift_amount);
  return Multiply(c, power_of_two);
}
394 
Encrypt(const BigNum & m) const395 StatusOr<BigNum> PublicPaillier::Encrypt(const BigNum& m) const {
396   if (!m.IsNonNegative()) {
397     return InvalidArgumentError(
398         "PublicPaillier::Encrypt() - Cannot encrypt negative number.");
399   }
400   if (m >= n_powers_[s_]) {
401     return InvalidArgumentError(
402         "PublicPaillier::Encrypt() - Message not smaller than n^s.");
403   }
404   return EncryptUsingGeneratorAndRand(m, ctx_->GenerateRandLessThan(n_));
405 }
406 
EncryptUsingGeneratorAndRand(const BigNum & m,const BigNum & r) const407 StatusOr<BigNum> PublicPaillier::EncryptUsingGeneratorAndRand(
408     const BigNum& m, const BigNum& r) const {
409   if (r > n_) {
410     return InvalidArgumentError(
411         "PublicPaillier: The given random is not less than or equal to n.");
412   }
413   BigNum c = ComputeByBinomialExpansion(ctx_, precomp_, n_powers_, m);
414   ASSIGN_OR_RETURN(BigNum g_n_to_r, g_n_fbe_->ModExp(r));
415   return c.ModMul(g_n_to_r, modulus_);
416 }
417 
EncryptWithRand(const BigNum & m,const BigNum & r) const418 StatusOr<BigNum> PublicPaillier::EncryptWithRand(const BigNum& m,
419                                                  const BigNum& r) const {
420   if (r.Gcd(n_) != ctx_->One()) {
421     return InvalidArgumentError(
422         "PublicPaillier::EncryptWithRand: The given random is not in Z*n.");
423   }
424   BigNum c = ComputeByBinomialExpansion(ctx_, precomp_, n_powers_, m);
425   return c.ModMul(r.ModExp(n_powers_[s_], modulus_), modulus_);
426 }
427 
EncryptAndGetRand(const BigNum & m) const428 StatusOr<PaillierEncAndRand> PublicPaillier::EncryptAndGetRand(
429     const BigNum& m) const {
430   BigNum r = ctx_->RelativelyPrimeRandomLessThan(n_);
431   ASSIGN_OR_RETURN(BigNum c, EncryptWithRand(m, r));
432   return {{std::move(c), std::move(r)}};
433 }
434 
435 PrivatePaillier::~PrivatePaillier() = default;
436 
PrivatePaillier(Context * ctx,const BigNum & p,const BigNum & q,int s)437 PrivatePaillier::PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q,
438                                  int s)
439     : ctx_(ctx),
440       n_to_s_((p * q).Exp(ctx_->CreateBigNum(s))),
441       n_to_s_plus_one_(n_to_s_ * p * q),
442       p_crypto_(new PrimeCrypto(ctx, p, q, s)),
443       q_crypto_(new PrimeCrypto(ctx, q, p, s)),
444       two_mod_crt_encrypt_(new TwoModulusCrt(p_crypto_->GetPToExp(s + 1),
445                                              q_crypto_->GetPToExp(s + 1))),
446       two_mod_crt_decrypt_(new TwoModulusCrt(p_crypto_->GetPToExp(s),
447                                              q_crypto_->GetPToExp(s))) {}
448 
PrivatePaillier(Context * ctx,const PaillierPrivateKey & private_key_proto)449 PrivatePaillier::PrivatePaillier(Context* ctx,
450                                  const PaillierPrivateKey& private_key_proto)
451     : PrivatePaillier(ctx, ctx->CreateBigNum(private_key_proto.p()),
452                       ctx->CreateBigNum(private_key_proto.q()),
453                       private_key_proto.s()) {}
454 
Encrypt(const BigNum & m) const455 StatusOr<BigNum> PrivatePaillier::Encrypt(const BigNum& m) const {
456   if (!m.IsNonNegative()) {
457     return InvalidArgumentError(
458         "PrivatePaillier::Encrypt() - Cannot encrypt negative number.");
459   }
460   if (m >= n_to_s_) {
461     return InvalidArgumentError(
462         "PrivatePaillier::Encrypt() - Message not smaller than n^s.");
463   }
464   ASSIGN_OR_RETURN(BigNum p_ct, p_crypto_->Encrypt(m));
465   ASSIGN_OR_RETURN(BigNum q_ct, q_crypto_->Encrypt(m));
466   return two_mod_crt_encrypt_->Compute(p_ct, q_ct);
467 }
468 
PrivatePaillier(Context * ctx,const BigNum & p,const BigNum & q)469 PrivatePaillier::PrivatePaillier(Context* ctx, const BigNum& p, const BigNum& q)
470     : PrivatePaillier(ctx, p, q, kDefaultS) {}
471 
Decrypt(const BigNum & c) const472 StatusOr<BigNum> PrivatePaillier::Decrypt(const BigNum& c) const {
473   if (!c.IsNonNegative()) {
474     return InvalidArgumentError(
475         "PrivatePaillier::Decrypt() - Cannot decrypt negative number.");
476   }
477   if (c >= n_to_s_plus_one_) {
478     return InvalidArgumentError(
479         "PrivatePaillier::Decrypt() - Ciphertext not smaller than n^(s+1).");
480   }
481   return two_mod_crt_decrypt_->Compute(p_crypto_->Decrypt(c),
482                                        q_crypto_->Decrypt(c));
483 }
484 
PrivatePaillierWithRand(PrivatePaillier * private_paillier)485 PrivatePaillierWithRand::PrivatePaillierWithRand(
486     PrivatePaillier* private_paillier)
487     : ctx_(private_paillier->ctx_), private_paillier_(private_paillier) {
488   const BigNum& p = private_paillier_->p_crypto_->GetPToExp(1);
489   const BigNum& q = private_paillier_->q_crypto_->GetPToExp(1);
490   two_mod_crt_rand_ = std::make_unique<TwoModulusCrt>(p, q);
491   p_crypto_ =
492       std::make_unique<PrimeCryptoWithRand>(private_paillier_->p_crypto_.get());
493   q_crypto_ =
494       std::make_unique<PrimeCryptoWithRand>(private_paillier_->q_crypto_.get());
495 }
496 
497 PrivatePaillierWithRand::~PrivatePaillierWithRand() = default;
498 
Encrypt(const BigNum & m) const499 StatusOr<BigNum> PrivatePaillierWithRand::Encrypt(const BigNum& m) const {
500   return private_paillier_->Encrypt(m);
501 }
502 
EncryptAndGetRand(const BigNum & m) const503 StatusOr<PaillierEncAndRand> PrivatePaillierWithRand::EncryptAndGetRand(
504     const BigNum& m) const {
505   if (!m.IsNonNegative()) {
506     return InvalidArgumentError(
507         "PrivatePaillier::Encrypt() - Cannot encrypt negative number.");
508   }
509   if (m >= private_paillier_->n_to_s_) {
510     return InvalidArgumentError(
511         "PrivatePaillier::Encrypt() - Message not smaller than n^s.");
512   }
513 
514   ASSIGN_OR_RETURN(const PaillierEncAndRand enc_p,
515                    p_crypto_->EncryptAndGetRand(m));
516   ASSIGN_OR_RETURN(const PaillierEncAndRand enc_q,
517                    q_crypto_->EncryptAndGetRand(m));
518 
519   BigNum c = private_paillier_->two_mod_crt_encrypt_->Compute(enc_p.ciphertext,
520                                                               enc_q.ciphertext);
521   BigNum r = two_mod_crt_rand_->Compute(enc_p.rand, enc_q.rand);
522   return {{std::move(c), std::move(r)}};
523 }
524 
Decrypt(const BigNum & c) const525 StatusOr<BigNum> PrivatePaillierWithRand::Decrypt(const BigNum& c) const {
526   return private_paillier_->Decrypt(c);
527 }
528 
529 }  // namespace private_join_and_compute
530