1 /*
2 * Copyright 2018 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/client/secagg_client.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "fcp/base/monitoring.h"
27 #include "fcp/secagg/client/send_to_server_interface.h"
28 #include "fcp/secagg/client/state_transition_listener_interface.h"
29 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
30 #include "fcp/secagg/shared/compute_session_id.h"
31 #include "fcp/secagg/shared/input_vector_specification.h"
32 #include "fcp/secagg/shared/secagg_messages.pb.h"
33 #include "fcp/secagg/shared/secagg_vector.h"
34 #include "fcp/secagg/testing/ecdh_pregenerated_test_keys.h"
35 #include "fcp/secagg/testing/fake_prng.h"
36 #include "fcp/secagg/testing/mock_send_to_server_interface.h"
37 #include "fcp/secagg/testing/mock_state_transition_listener.h"
38 #include "fcp/testing/testing.h"
39
// All of the actual client functionality is contained within the
// SecAggClient*State classes. This file only tests very basic functionality
// of the containing SecAggClient class.
43
44 namespace fcp {
45 namespace secagg {
46 namespace {
47
48 using ::testing::_;
49 using ::testing::Eq;
50 using ::testing::Pointee;
51
// Checks what a freshly constructed SecAggClient reports about itself before
// Start() or SetInput() has been called.
TEST(SecAggClientTest, ConstructedWithCorrectState) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);
  auto sender = std::make_unique<MockSendToServerInterface>();
  auto listener = std::make_unique<MockStateTransitionListener>();
  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(sender), std::move(listener),
      std::make_unique<AesCtrPrngFactory>());

  // The client reports the initial round 0 state and is neither aborted nor
  // completed.
  EXPECT_THAT(client.State(), Eq("R0_ADVERTISE_KEYS_INPUT_NOT_SET"));
  EXPECT_THAT(client.IsCompletedSuccessfully(), Eq(false));
  EXPECT_THAT(client.IsAborted(), Eq(false));
}
70
// Checks that Start() returns OK, sends a message through the sender, and
// leaves the client in the round 1 "input not set" state.
TEST(SecAggClientTest, StartCausesStateTransition) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);
  // The client takes ownership of the sender; keep a non-owning pointer so
  // expectations can still be set on the mock.
  auto sender_owner = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* sender = sender_owner.get();
  auto listener = std::make_unique<MockStateTransitionListener>();
  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(sender_owner),
      std::move(listener), std::make_unique<AesCtrPrngFactory>());

  // Only the fact that a message is sent matters here; its contents are
  // verified by the tests for the Round 0 classes.
  EXPECT_CALL(*sender, Send(_));
  Status start_status = client.Start();

  EXPECT_THAT(start_status.code(), Eq(OK));
  EXPECT_THAT(client.State(), Eq("R1_SHARE_KEYS_INPUT_NOT_SET"));
  EXPECT_THAT(client.IsCompletedSuccessfully(), Eq(false));
  EXPECT_THAT(client.IsAborted(), Eq(false));
}
95
// Checks the values returned by ReceiveMessage: true for a valid message that
// lets the client continue, false for a server abort, and an error status for
// anything received after the abort.
TEST(SecAggClientTest, ReceiveMessageReturnValuesAreCorrect) {
  // How the client reacts to each message is covered by the state class test
  // files; this test is only concerned with ReceiveMessage's return values.
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);
  // The client takes ownership of the sender; keep a non-owning pointer so
  // expectations can still be set on the mock.
  auto sender_owner = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* sender = sender_owner.get();
  auto listener = std::make_unique<MockStateTransitionListener>();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(sender_owner),
      std::move(listener), std::make_unique<AesCtrPrngFactory>());

  // Start the client so that it is able to receive a message, capturing the
  // message it sends so its advertised keys can be reused below.
  ClientToServerWrapperMessage advertised;
  EXPECT_CALL(*sender, Send(_))
      .WillOnce(::testing::SaveArgPointee<0>(&advertised));
  EXPECT_THAT(client.Start(), IsOk());

  // Build a share-keys request with four key pairs: slot 1 carries the keys
  // this client advertised, every other slot uses pregenerated test keys.
  ServerToClientWrapperMessage share_keys_message;
  auto* request = share_keys_message.mutable_share_keys_request();
  EcdhPregeneratedTestKeys test_keys;
  for (int slot = 0; slot < 4; ++slot) {
    PairOfPublicKeys* pair = request->add_pairs_of_public_keys();
    if (slot != 1) {
      pair->set_enc_pk(test_keys.GetPublicKeyString(2 * slot));
      pair->set_noise_pk(test_keys.GetPublicKeyString(2 * slot + 1));
      continue;
    }
    *pair = advertised.advertise_keys().pair_of_public_keys();
  }
  // The session id is derived from the request once all key pairs are in.
  request->set_session_id(ComputeSessionId(*request).data);

  EXPECT_CALL(*sender, Send(_));

  // A valid message from the server should return true if it can continue.
  StatusOr<bool> outcome = client.ReceiveMessage(share_keys_message);
  ASSERT_THAT(outcome.ok(), Eq(true));
  EXPECT_THAT(outcome.value(), Eq(true));

  // An abort message from the server should return false.
  ServerToClientWrapperMessage server_abort;
  server_abort.mutable_abort()->set_early_success(false);
  outcome = client.ReceiveMessage(server_abort);
  ASSERT_THAT(outcome.ok(), Eq(true));
  EXPECT_THAT(outcome.value(), Eq(false));

  // Any other message after abort should raise an error.
  outcome = client.ReceiveMessage(server_abort);
  EXPECT_THAT(outcome.ok(), Eq(false));
}
154
// Checks that Abort(reason) returns OK, sends the server an abort message
// whose diagnostic info embeds the reason, moves the client to "ABORTED", and
// exposes the same text through ErrorMessage().
TEST(SecAggClientTest, AbortMovesToCorrectStateAndSendsMessageToServer) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);
  // The client takes ownership of the sender; keep a non-owning pointer so
  // expectations can still be set on the mock.
  auto sender_owner = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* sender = sender_owner.get();
  auto listener = std::make_unique<MockStateTransitionListener>();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(sender_owner),
      std::move(listener), std::make_unique<AesCtrPrngFactory>());

  // The message the client is expected to send to the server on abort.
  const std::string expected_diagnostic =
      "Abort upon external request for reason <Abort reason>.";
  ClientToServerWrapperMessage expected_abort;
  expected_abort.mutable_abort()->set_diagnostic_info(expected_diagnostic);
  EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_abort))));

  Status abort_status = client.Abort("Abort reason");
  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_THAT(client.State(), Eq("ABORTED"));
  EXPECT_THAT(client.ErrorMessage().value(), Eq(expected_diagnostic));
}
182
// Checks that Abort() called without a reason returns OK, sends the server an
// abort message with the "<unknown reason>" diagnostic text, moves the client
// to "ABORTED", and exposes the same text through ErrorMessage().
TEST(SecAggClientTest,
     AbortWithNoMessageMovesToCorrectStateAndSendsMessageToServer) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);
  // The client takes ownership of the sender; keep a non-owning pointer so
  // expectations can still be set on the mock.
  auto sender_owner = std::make_unique<MockSendToServerInterface>();
  MockSendToServerInterface* sender = sender_owner.get();
  auto listener = std::make_unique<MockStateTransitionListener>();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(sender_owner),
      std::move(listener), std::make_unique<AesCtrPrngFactory>());

  // The message the client is expected to send to the server on abort.
  const std::string expected_diagnostic =
      "Abort upon external request for reason <unknown reason>.";
  ClientToServerWrapperMessage expected_abort;
  expected_abort.mutable_abort()->set_diagnostic_info(expected_diagnostic);
  EXPECT_CALL(*sender, Send(Pointee(EqualsProto(expected_abort))));

  Status abort_status = client.Abort();
  EXPECT_THAT(abort_status.code(), Eq(OK));
  EXPECT_THAT(client.State(), Eq("ABORTED"));
  EXPECT_THAT(client.ErrorMessage().value(), Eq(expected_diagnostic));
}
211
// Checks that ErrorMessage() returns a non-OK result on a client that has not
// been aborted.
TEST(SecAggClientTest, ErrorMessageRaisesErrorStatusIfNotAborted) {
  std::vector<InputVectorSpecification> specs;
  specs.emplace_back("test", 4, 32);
  auto sender = std::make_unique<MockSendToServerInterface>();
  auto listener = std::make_unique<MockStateTransitionListener>();

  SecAggClient client(
      /*max_neighbors_expected=*/4,
      /*minimum_surviving_neighbors_for_reconstruction=*/3, specs,
      std::make_unique<FakePrng>(), std::move(sender), std::move(listener),
      std::make_unique<AesCtrPrngFactory>());

  // Nothing has aborted the client, so there is no error message to report.
  EXPECT_THAT(client.ErrorMessage().ok(), Eq(false));
}
230
// Checks that the first SetInput() call succeeds, and that a second call with
// another valid input map is rejected with FAILED_PRECONDITION while the
// client stays in the round 0 "input set" state.
TEST(SecAggClientTest, SetInputChangesStateOnlyOnce) {
  std::vector<InputVectorSpecification> input_vector_specs;
  input_vector_specs.push_back(InputVectorSpecification("test", 4, 32));
  auto sender = std::make_unique<MockSendToServerInterface>();
  auto transition_listener = std::make_unique<MockStateTransitionListener>();

  SecAggClient client(
      4,  // max_neighbors_expected
      3,  // minimum_surviving_neighbors_for_reconstruction
      input_vector_specs, std::make_unique<FakePrng>(), std::move(sender),
      std::move(transition_listener), std::make_unique<AesCtrPrngFactory>());

  auto input_map = std::make_unique<SecAggVectorMap>();
  input_map->emplace("test", SecAggVector({5, 8, 22, 30}, 32));

  Status result = client.SetInput(std::move(input_map));
  EXPECT_THAT(result.code(), Eq(OK));

  // Submit a second, independently built input. The first map was moved from
  // above and is now null, so it must not be passed again: the rejection has
  // to be exercised with a well-formed input, not a null pointer.
  auto input_map2 = std::make_unique<SecAggVectorMap>();
  input_map2->emplace("test", SecAggVector({5, 8, 22, 30}, 32));
  result = client.SetInput(std::move(input_map2));
  EXPECT_THAT(result.code(), Eq(FAILED_PRECONDITION));
  EXPECT_THAT(client.State(), Eq("R0_ADVERTISE_KEYS_INPUT_SET"));
}
259
260 } // namespace
261 } // namespace secagg
262 } // namespace fcp
263