xref: /aosp_15_r20/external/federated-compute/fcp/secagg/client/secagg_client_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2018 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/client/secagg_client.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/client/send_to_server_interface.h"
28 #include "fcp/secagg/client/state_transition_listener_interface.h"
29 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
30 #include "fcp/secagg/shared/compute_session_id.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/shared/secagg_vector.h"
34 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
35 #include "fcp/secagg/testing/fake_prng.h"
36 #include "fcp/secagg/testing/mock_send_to_server_interface.h"
37 #include "fcp/secagg/testing/mock_state_transition_listener.h"
38 #include "fcp/testing/testing.h"
39 
40 // All of the actual client functionality is contained within the
41 // SecAggClient*State classes. This class only tests very basic functionality
42 // of the containing SecAggClient class.
43 
44 namespace fcp {
45 namespace secagg {
46 namespace {
47 
48 using ::testing::_;
49 using ::testing::Eq;
50 using ::testing::Pointee;
51 
// Checks the observable state of a client immediately after construction: it
// is neither aborted nor successfully completed, and it reports the initial
// round-0 state in which no input has been set yet.
TEST(SecAggClientTest, ConstructedWithCorrectState) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);

  // The client takes ownership of every collaborator handed to it; this test
  // sets no expectations on the mocks, so no handles to them are kept.
  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(),
      std::make_unique<MockSendToServerInterface>(),
      std::make_unique<MockStateTransitionListener>(),
      std::make_unique<AesCtrPrngFactory>());

  EXPECT_FALSE(client.IsAborted());
  EXPECT_FALSE(client.IsCompletedSuccessfully());
  EXPECT_THAT(client.State(), Eq("R0_ADVERTISE_KEYS_INPUT_NOT_SET"));
}
70 
// Checks that Start() sends a message to the server, returns OK, and moves the
// client from the round-0 state into the round-1 (share keys) state.
TEST(SecAggClientTest, StartCausesStateTransition) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);

  // Ownership of the sender mock is transferred to the client below; a
  // non-owning pointer is retained so that expectations can be set on it.
  auto owned_sender = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* const sender = owned_sender.get();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(owned_sender),
      std::make_unique<MockStateTransitionListener>(),
      std::make_unique<AesCtrPrngFactory>());

  // Message correctness is checked in the tests for the Round 0 classes; here
  // it only matters that something is sent.
  EXPECT_CALL(*sender, Send(::testing::_));
  const Status start_status = client.Start();

  EXPECT_THAT(start_status.code(), Eq(OK));
  EXPECT_FALSE(client.IsAborted());
  EXPECT_FALSE(client.IsCompletedSuccessfully());
  EXPECT_THAT(client.State(), Eq("R1_SHARE_KEYS_INPUT_NOT_SET"));
}
95 
TEST(SecAggClientTest,ReceiveMessageReturnValuesAreCorrect)96 TEST(SecAggClientTest, ReceiveMessageReturnValuesAreCorrect) {
97   // The actual behavior of the client upon receipt of messages is tested in the
98   // state class test files; here we just check that ReceiveMessage returns
99   // values correctly.
100   std::vector<InputVectorSpecification> input_vector_specs;
101   input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
102   MockSendToServerInterface* sender = new MockSendToServerInterface();
103   MockStateTransitionListener* transition_listener =
104       new MockStateTransitionListener();
105 
106   SecAggClient client(
107       4,  // max_neighbors_expected
108       3,  // minimum_surviving_neighbors_for_reconstruction
109       input_vector_specs, std::make_unique<FakePrng>(),
110       std::unique_ptr<SendToServerInterface>(sender),
111       std::unique_ptr<StateTransitionListenerInterface>(transition_listener),
112 
113       std::make_unique<AesCtrPrngFactory>());
114 
115   // Get the client into a state where it can receive a message.
116   ClientToServerWrapperMessage round_0_client_message;
117   EXPECT_CALL(*sender, Send(_))
118       .WillOnce(::testing::SaveArgPointee<0>(&round_0_client_message));
119   EXPECT_THAT(client.Start(), IsOk());
120 
121   ServerToClientWrapperMessage round_1_message;
122   EcdhPregeneratedTestKeys ecdh_keys;
123   for (int i = 0; i < 4; ++i) {
124     PairOfPublicKeys* keypair = round_1_message.mutable_share_keys_request()
125                                     ->add_pairs_of_public_keys();
126     if (i == 1) {
127       *keypair = round_0_client_message.advertise_keys().pair_of_public_keys();
128     } else {
129       keypair->set_enc_pk(ecdh_keys.GetPublicKeyString(2 * i));
130       keypair->set_noise_pk(ecdh_keys.GetPublicKeyString(2 * i + 1));
131     }
132   }
133   round_1_message.mutable_share_keys_request()->set_session_id(
134       ComputeSessionId(round_1_message.share_keys_request()).data);
135 
136   EXPECT_CALL(*sender, Send(_));
137 
138   // A valid message from the server should return true if it can continue.
139   StatusOr<bool> result = client.ReceiveMessage(round_1_message);
140   ASSERT_THAT(result.ok(), Eq(true));
141   EXPECT_THAT(result.value(), Eq(true));
142 
143   // An abort message from the server should return false.
144   ServerToClientWrapperMessage abort_message;
145   abort_message.mutable_abort()->set_early_success(false);
146   result = client.ReceiveMessage(abort_message);
147   ASSERT_THAT(result.ok(), Eq(true));
148   EXPECT_THAT(result.value(), Eq(false));
149 
150   // Any other message after abort should raise an error.
151   result = client.ReceiveMessage(abort_message);
152   EXPECT_THAT(result.ok(), Eq(false));
153 }
154 
// Checks that Abort(reason) returns OK, sends the server an abort message
// whose diagnostic info embeds the reason, leaves the client in the ABORTED
// state, and exposes the same text through ErrorMessage().
TEST(SecAggClientTest, AbortMovesToCorrectStateAndSendsMessageToServer) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);

  // Ownership of the sender mock is transferred to the client below; a
  // non-owning pointer is retained so that expectations can be set on it.
  auto owned_sender = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* const sender = owned_sender.get();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(owned_sender),
      std::make_unique<MockStateTransitionListener>(),
      std::make_unique<AesCtrPrngFactory>());

  const std::string expected_diagnostic =
      "Abort upon external request for reason <Abort reason>.";
  ClientToServerWrapperMessage expected_abort;
  expected_abort.mutable_abort()->set_diagnostic_info(expected_diagnostic);
  EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_abort))));

  const Status abort_status = client.Abort("Abort reason");

  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_THAT(client.State(), Eq("ABORTED"));
  EXPECT_THAT(client.ErrorMessage().value(), Eq(expected_diagnostic));
}
182 
// Checks that Abort() called without a reason returns OK, sends the server an
// abort message whose diagnostic info reports "<unknown reason>", leaves the
// client in the ABORTED state, and exposes the same text via ErrorMessage().
TEST(SecAggClientTest,
     AbortWithNoMessageMovesToCorrectStateAndSendsMessageToServer) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);

  // Ownership of the sender mock is transferred to the client below; a
  // non-owning pointer is retained so that expectations can be set on it.
  auto owned_sender = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* const sender = owned_sender.get();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(owned_sender),
      std::make_unique<MockStateTransitionListener>(),
      std::make_unique<AesCtrPrngFactory>());

  const std::string expected_diagnostic =
      "Abort upon external request for reason <unknown reason>.";
  ClientToServerWrapperMessage expected_abort;
  expected_abort.mutable_abort()->set_diagnostic_info(expected_diagnostic);
  EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_abort))));

  const Status abort_status = client.Abort();

  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_THAT(client.State(), Eq("ABORTED"));
  EXPECT_THAT(client.ErrorMessage().value(), Eq(expected_diagnostic));
}
211 
// Checks that ErrorMessage() yields a non-OK result on a client that has not
// been aborted.
TEST(SecAggClientTest, ErrorMessageRaisesErrorStatusIfNotAborted) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);

  // The client takes ownership of every collaborator handed to it; this test
  // sets no expectations on the mocks, so no handles to them are kept.
  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(),
      std::make_unique<MockSendToServerInterface>(),
      std::make_unique<MockStateTransitionListener>(),
      std::make_unique<AesCtrPrngFactory>());

  EXPECT_FALSE(client.ErrorMessage().ok());
}
230 
// Checks that input can be set exactly once: the first SetInput() succeeds,
// a second SetInput() with another valid input map is rejected with
// FAILED_PRECONDITION, and the client remains in the round-0 "input set"
// state afterwards.
TEST(SecAggClientTest, SetInputChangesStateOnlyOnce) {
  std::vector<InputVectorSpecification> input_vector_specs;
  input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));

  // The client takes ownership of every collaborator handed to it; this test
  // sets no expectations on the mocks, so no handles to them are kept.
  SecAggClient client(
      4,  // max_neighbors_expected
      3,  // minimum_surviving_neighbors_for_reconstruction
      input_vector_specs, std::make_unique<FakePrng>(),
      std::make_unique<MockSendToServerInterface>(),
      std::make_unique<MockStateTransitionListener>(),
      std::make_unique<AesCtrPrngFactory>());

  auto first_input = std::make_unique<SecAggVectorMap>();
  first_input->emplace("test", SecAggVector({5, 8, 22, 30}, 32));

  Status result = client.SetInput(std::move(first_input));
  EXPECT_THAT(result.code(), Eq(OK));

  // The second attempt must hand over the freshly built map. Passing the
  // first map again would pass a moved-from (null) pointer, so the rejection
  // would not be shown to hold for a well-formed second input.
  auto second_input = std::make_unique<SecAggVectorMap>();
  second_input->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
  result = client.SetInput(std::move(second_input));
  EXPECT_THAT(result.code(), Eq(FAILED_PRECONDITION));
  EXPECT_THAT(client.State(), Eq("R0_ADVERTISE_KEYS_INPUT_SET"));
}
259 
260 }  // namespace
261 }  // namespace secagg
262 }  // namespace fcp
263