xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/aes/aes_secagg_server_protocol_impl.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_
18 #define FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_
19 
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "fcp/secagg/server/secagg_scheduler.h"
27 #include "fcp/secagg/server/secagg_server_enums.pb.h"
28 #include "fcp/secagg/server/secagg_server_protocol_impl.h"
29 #include "fcp/secagg/server/tracing_schema.h"
30 #include "fcp/secagg/shared/secagg_vector.h"
31 #include "fcp/tracing/tracing_span.h"
32 
33 namespace fcp {
34 namespace secagg {
35 
36 class AesSecAggServerProtocolImpl
37     : public SecAggServerProtocolImpl,
38       public std::enable_shared_from_this<AesSecAggServerProtocolImpl> {
39  public:
40   AesSecAggServerProtocolImpl(
41       std::unique_ptr<SecretSharingGraph> graph,
42       int minimum_number_of_clients_to_proceed,
43       std::vector<InputVectorSpecification> input_vector_specs,
44       std::unique_ptr<SecAggServerMetricsListener> metrics,
45       std::unique_ptr<AesPrngFactory> prng_factory,
46       SendToClientsInterface* sender,
47       std::unique_ptr<SecAggScheduler> scheduler,
48       std::vector<ClientStatus> client_statuses, ServerVariant server_variant,
49       std::unique_ptr<ExperimentsInterface> experiments = nullptr)
SecAggServerProtocolImpl(std::move (graph),minimum_number_of_clients_to_proceed,std::move (metrics),std::move (prng_factory),sender,std::move (scheduler),std::move (client_statuses),std::move (experiments))50       : SecAggServerProtocolImpl(
51             std::move(graph), minimum_number_of_clients_to_proceed,
52             std::move(metrics), std::move(prng_factory), sender,
53             std::move(scheduler), std::move(client_statuses),
54             std::move(experiments)),
55         server_variant_(server_variant),
56         input_vector_specs_(std::move(input_vector_specs)) {}
57 
server_variant()58   ServerVariant server_variant() const override { return server_variant_; }
59 
60   // Returns one InputVectorSpecification for each input vector which the
61   // protocol will aggregate.
input_vector_specs()62   inline const std::vector<InputVectorSpecification>& input_vector_specs()
63       const {
64     return input_vector_specs_;
65   }
66 
InitializeShareKeysRequest(ShareKeysRequest * request)67   Status InitializeShareKeysRequest(ShareKeysRequest* request) const override {
68     return ::absl::OkStatus();
69   }
70 
71   // TODO(team): Remove this method. This field must be set from
72   // inside the protocol implementation.
set_masked_input(std::unique_ptr<SecAggUnpackedVectorMap> masked_input)73   void set_masked_input(std::unique_ptr<SecAggUnpackedVectorMap> masked_input) {
74     masked_input_ = std::move(masked_input);
75   }
76 
77   // Takes out ownership the accumulated queue of masked inputs and empties
78   // the current queue.
79   std::vector<std::unique_ptr<SecAggVectorMap>> TakeMaskedInputQueue();
80 
81   std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
82   SetupMaskedInputCollection() override;
83 
84   void FinalizeMaskedInputCollection() override;
85 
86   Status HandleMaskedInputCollectionResponse(
87       std::unique_ptr<MaskedInputCollectionResponse> masked_input_response)
88       override;
89 
90   CancellationToken StartPrng(
91       const PrngWorkItems& work_items,
92       std::function<void(Status)> done_callback) override;
93 
94  private:
95   std::unique_ptr<SecAggUnpackedVectorMap> masked_input_;
96   // Protects masked_input_queue_.
97   absl::Mutex mutex_;
98   std::vector<std::unique_ptr<SecAggVectorMap>> masked_input_queue_
99       ABSL_GUARDED_BY(mutex_);
100   std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>>
101       masked_input_accumulator_;
102   ServerVariant server_variant_;
103   std::vector<InputVectorSpecification> input_vector_specs_;
104 };
105 
106 }  // namespace secagg
107 }  // namespace fcp
108 
109 #endif  // FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_
110