1 /* 2 * Copyright 2020 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * https://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_ 18 #define FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_ 19 20 #include <functional> 21 #include <memory> 22 #include <utility> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "fcp/secagg/server/secagg_scheduler.h" 27 #include "fcp/secagg/server/secagg_server_enums.pb.h" 28 #include "fcp/secagg/server/secagg_server_protocol_impl.h" 29 #include "fcp/secagg/server/tracing_schema.h" 30 #include "fcp/secagg/shared/secagg_vector.h" 31 #include "fcp/tracing/tracing_span.h" 32 33 namespace fcp { 34 namespace secagg { 35 36 class AesSecAggServerProtocolImpl 37 : public SecAggServerProtocolImpl, 38 public std::enable_shared_from_this<AesSecAggServerProtocolImpl> { 39 public: 40 AesSecAggServerProtocolImpl( 41 std::unique_ptr<SecretSharingGraph> graph, 42 int minimum_number_of_clients_to_proceed, 43 std::vector<InputVectorSpecification> input_vector_specs, 44 std::unique_ptr<SecAggServerMetricsListener> metrics, 45 std::unique_ptr<AesPrngFactory> prng_factory, 46 SendToClientsInterface* sender, 47 std::unique_ptr<SecAggScheduler> scheduler, 48 std::vector<ClientStatus> client_statuses, ServerVariant server_variant, 49 std::unique_ptr<ExperimentsInterface> experiments = nullptr) SecAggServerProtocolImpl(std::move (graph),minimum_number_of_clients_to_proceed,std::move (metrics),std::move (prng_factory),sender,std::move (scheduler),std::move (client_statuses),std::move (experiments))50 : SecAggServerProtocolImpl( 51 std::move(graph), minimum_number_of_clients_to_proceed, 52 std::move(metrics), std::move(prng_factory), sender, 53 std::move(scheduler), std::move(client_statuses), 54 std::move(experiments)), 55 server_variant_(server_variant), 56 input_vector_specs_(std::move(input_vector_specs)) {} 57 server_variant()58 ServerVariant server_variant() const override { return server_variant_; } 59 60 // Returns one InputVectorSpecification for each input vector which the 61 // protocol will aggregate. input_vector_specs()62 inline const std::vector<InputVectorSpecification>& input_vector_specs() 63 const { 64 return input_vector_specs_; 65 } 66 InitializeShareKeysRequest(ShareKeysRequest * request)67 Status InitializeShareKeysRequest(ShareKeysRequest* request) const override { 68 return ::absl::OkStatus(); 69 } 70 71 // TODO(team): Remove this method. This field must be set from 72 // inside the protocol implementation. set_masked_input(std::unique_ptr<SecAggUnpackedVectorMap> masked_input)73 void set_masked_input(std::unique_ptr<SecAggUnpackedVectorMap> masked_input) { 74 masked_input_ = std::move(masked_input); 75 } 76 77 // Takes out ownership the accumulated queue of masked inputs and empties 78 // the current queue. 79 std::vector<std::unique_ptr<SecAggVectorMap>> TakeMaskedInputQueue(); 80 81 std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>> 82 SetupMaskedInputCollection() override; 83 84 void FinalizeMaskedInputCollection() override; 85 86 Status HandleMaskedInputCollectionResponse( 87 std::unique_ptr<MaskedInputCollectionResponse> masked_input_response) 88 override; 89 90 CancellationToken StartPrng( 91 const PrngWorkItems& work_items, 92 std::function<void(Status)> done_callback) override; 93 94 private: 95 std::unique_ptr<SecAggUnpackedVectorMap> masked_input_; 96 // Protects masked_input_queue_. 97 absl::Mutex mutex_; 98 std::vector<std::unique_ptr<SecAggVectorMap>> masked_input_queue_ 99 ABSL_GUARDED_BY(mutex_); 100 std::shared_ptr<Accumulator<SecAggUnpackedVectorMap>> 101 masked_input_accumulator_; 102 ServerVariant server_variant_; 103 std::vector<InputVectorSpecification> input_vector_specs_; 104 }; 105 106 } // namespace secagg 107 } // namespace fcp 108 109 #endif // FCP_SECAGG_SERVER_AES_AES_SECAGG_SERVER_PROTOCOL_IMPL_H_ 110