1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h"
18
19 #include <atomic>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/cord.h"
29 #include "absl/synchronization/notification.h"
30 #include "fcp/aggregation/core/agg_vector_aggregator.h"
31 #include "fcp/aggregation/core/tensor.h"
32 #include "fcp/aggregation/core/tensor_aggregator_factory.h"
33 #include "fcp/aggregation/core/tensor_aggregator_registry.h"
34 #include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
35 #include "fcp/aggregation/protocol/checkpoint_builder.h"
36 #include "fcp/aggregation/protocol/checkpoint_parser.h"
37 #include "fcp/aggregation/protocol/configuration.pb.h"
38 #include "fcp/aggregation/protocol/testing/test_callback.h"
39 #include "fcp/aggregation/testing/test_data.h"
40 #include "fcp/aggregation/testing/testing.h"
41 #include "fcp/base/monitoring.h"
42 #include "fcp/base/scheduler.h"
43 #include "fcp/testing/testing.h"
44
45 namespace fcp::aggregation {
46 namespace {
47
48 using ::testing::_;
49 using ::testing::ByMove;
50 using ::testing::Eq;
51 using ::testing::Invoke;
52 using ::testing::Return;
53 using ::testing::StrEq;
54
55 // TODO(team): Consider moving mock classes into a separate test library.
56 class MockCheckpointParser : public CheckpointParser {
57 public:
58 MOCK_METHOD(absl::StatusOr<Tensor>, GetTensor, (const std::string& name),
59 (const override));
60 };
61
62 class MockCheckpointParserFactory : public CheckpointParserFactory {
63 public:
64 MOCK_METHOD(absl::StatusOr<std::unique_ptr<CheckpointParser>>, Create,
65 (const absl::Cord& serialized_checkpoint), (const override));
66 };
67
68 class MockCheckpointBuilder : public CheckpointBuilder {
69 public:
70 MOCK_METHOD(absl::Status, Add,
71 (const std::string& name, const Tensor& tensor), (override));
72 MOCK_METHOD(absl::StatusOr<absl::Cord>, Build, (), (override));
73 };
74
75 class MockCheckpointBuilderFactory : public CheckpointBuilderFactory {
76 public:
77 MOCK_METHOD(std::unique_ptr<CheckpointBuilder>, Create, (), (const override));
78 };
79
80 class MockResourceResolver : public ResourceResolver {
81 public:
82 MOCK_METHOD(absl::StatusOr<absl::Cord>, RetrieveResource,
83 (int64_t client_id, const std::string& uri), (override));
84 };
85
86 class SimpleAggregationProtocolTest : public ::testing::Test {
87 protected:
88 // Returns default configuration.
89 Configuration default_configuration() const;
90
91 // Returns the default instance of checkpoint bilder;
ExpectCheckpointBuilder()92 MockCheckpointBuilder& ExpectCheckpointBuilder() {
93 MockCheckpointBuilder& checkpoint_builder = *wrapped_checkpoint_builder_;
94 EXPECT_CALL(checkpoint_builder_factory_, Create())
95 .WillOnce(Return(ByMove(std::move(wrapped_checkpoint_builder_))));
96 EXPECT_CALL(checkpoint_builder, Build()).WillOnce(Return(absl::Cord{}));
97 return checkpoint_builder;
98 }
99
100 // Creates an instance of SimpleAggregationProtocol with the specified config.
101 std::unique_ptr<SimpleAggregationProtocol> CreateProtocol(
102 Configuration config);
103
104 // Creates an instance of SimpleAggregationProtocol with the default config.
CreateProtocolWithDefaultConfig()105 std::unique_ptr<SimpleAggregationProtocol> CreateProtocolWithDefaultConfig() {
106 return CreateProtocol(default_configuration());
107 }
108
109 MockAggregationProtocolCallback callback_;
110
111 MockCheckpointParserFactory checkpoint_parser_factory_;
112 MockCheckpointBuilderFactory checkpoint_builder_factory_;
113 MockResourceResolver resource_resolver_;
114
115 private:
116 std::unique_ptr<MockCheckpointBuilder> wrapped_checkpoint_builder_ =
117 std::make_unique<MockCheckpointBuilder>();
118 };
119
default_configuration() const120 Configuration SimpleAggregationProtocolTest::default_configuration() const {
121 // One "federated_sum" intrinsic with a single scalar int32 tensor.
122 return PARSE_TEXT_PROTO(R"pb(
123 aggregation_configs {
124 intrinsic_uri: "federated_sum"
125 intrinsic_args {
126 input_tensor {
127 name: "foo"
128 dtype: DT_INT32
129 shape {}
130 }
131 }
132 output_tensors {
133 name: "foo_out"
134 dtype: DT_INT32
135 shape {}
136 }
137 }
138 )pb");
139 }
140
141 std::unique_ptr<SimpleAggregationProtocol>
CreateProtocol(Configuration config)142 SimpleAggregationProtocolTest::CreateProtocol(Configuration config) {
143 // Verify that the protocol can be created successfully.
144 absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>>
145 protocol_or_status = SimpleAggregationProtocol::Create(
146 config, &callback_, &checkpoint_parser_factory_,
147 &checkpoint_builder_factory_, &resource_resolver_);
148 EXPECT_THAT(protocol_or_status, IsOk());
149 return std::move(protocol_or_status).value();
150 }
151
MakeClientMessage()152 ClientMessage MakeClientMessage() {
153 ClientMessage message;
154 message.mutable_simple_aggregation()->mutable_input()->set_inline_bytes("");
155 return message;
156 }
157
TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedNumberOfInputs) {
  // One "federated_sum" intrinsic that declares two input tensors.
  Configuration two_inputs_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {}
        }
      }
      intrinsic_args {
        input_tensor {
          name: "bar"
          dtype: DT_INT32
          shape {}
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      two_inputs_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
188
TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedNumberOfOutputs) {
  // One "federated_sum" intrinsic that declares two output tensors.
  Configuration two_outputs_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {}
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
      output_tensors {
        name: "bar_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      two_outputs_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
217
TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedInputType) {
  // The only intrinsic argument is a `parameter` rather than an
  // `input_tensor`.
  Configuration parameter_arg_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args { parameter {} }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      parameter_arg_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
235
TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedIntrinsicUri) {
  // Same shape as the default configuration, but with the intrinsic URI
  // "unsupported_xyz".
  Configuration unknown_uri_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "unsupported_xyz"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {}
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      unknown_uri_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
259
TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedInputSpec) {
  // The input tensor declares a dimension of size -1.
  Configuration negative_dim_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape { dim { size: -1 } }
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      negative_dim_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
283
TEST_F(SimpleAggregationProtocolTest, Create_UnmatchingInputAndOutputDataType) {
  // The input tensor is DT_INT32 while the output tensor is DT_FLOAT.
  Configuration dtype_mismatch_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {}
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_FLOAT
        shape {}
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      dtype_mismatch_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
307
TEST_F(SimpleAggregationProtocolTest, Create_UnmatchingInputAndOutputShape) {
  // The input tensor has a single dimension of size 1 while the output tensor
  // has a single dimension of size 2.
  Configuration shape_mismatch_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape { dim { size: 1 } }
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape { dim { size: 2 } }
      }
    }
  )pb");

  // Creation must be rejected as an invalid argument.
  const auto creation_result = SimpleAggregationProtocol::Create(
      shape_mismatch_config, &callback_, &checkpoint_parser_factory_,
      &checkpoint_builder_factory_, &resource_resolver_);
  EXPECT_THAT(creation_result, IsCode(INVALID_ARGUMENT));
}
331
TEST_F(SimpleAggregationProtocolTest, StartProtocol_Success) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // Starting with 3 clients must be reported through the callback. This
  // mirrors the OnAcceptClients(0, N, _) expectation that the other tests in
  // this file set before a successful Start(N).
  EXPECT_CALL(callback_, OnAcceptClients(0, 3, _));
  EXPECT_THAT(protocol->Start(3), IsOk());
  // All three clients are pending right after the start.
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 3")));
}
339
340 // TODO(team): Add similar tests for other callbacks.
TEST_F(SimpleAggregationProtocolTest,
       StartProtocol_AcceptClientsProtocolReentrace) {
  // Checks that the protocol can be called back into from within the
  // OnAcceptClients callback.
  auto protocol = CreateProtocolWithDefaultConfig();

  // Invoked from the callback: queries the protocol status and expects the
  // single client to be reported as pending.
  auto query_status_from_callback = [&]() {
    EXPECT_THAT(
        protocol->GetStatus(),
        EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
  };
  EXPECT_CALL(callback_, OnAcceptClients(0, 1, _))
      .WillOnce(Invoke(query_status_from_callback));

  const auto start_status = protocol->Start(1);
  EXPECT_THAT(start_status, IsOk());
}
354
TEST_F(SimpleAggregationProtocolTest, StartProtocol_MultipleCalls) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // Only the first Start call is expected to reach the callback.
  EXPECT_CALL(callback_, OnAcceptClients).Times(1);

  const auto first_start = protocol->Start(1);
  EXPECT_THAT(first_start, IsOk());

  // Starting an already started protocol must fail.
  const auto second_start = protocol->Start(1);
  EXPECT_THAT(second_start, IsCode(FAILED_PRECONDITION));
}
362
TEST_F(SimpleAggregationProtocolTest, StartProtocol_ZeroClients) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // With zero clients Start must succeed without any OnAcceptClients call.
  EXPECT_CALL(callback_, OnAcceptClients).Times(0);
  const auto start_status = protocol->Start(0);
  EXPECT_THAT(start_status, IsOk());
}
368
TEST_F(SimpleAggregationProtocolTest, StartProtocol_NegativeClients) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // A negative number of clients must be rejected without any
  // OnAcceptClients call.
  EXPECT_CALL(callback_, OnAcceptClients).Times(0);
  const auto start_status = protocol->Start(-1);
  EXPECT_THAT(start_status, IsCode(INVALID_ARGUMENT));
}
374
TEST_F(SimpleAggregationProtocolTest, AddClients_Success) {
  auto protocol = CreateProtocolWithDefaultConfig();

  // Start with a single client.
  EXPECT_CALL(callback_, OnAcceptClients(0, 1, _));
  const auto start_status = protocol->Start(1);
  EXPECT_THAT(start_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));

  // Add three more clients; all four must then be pending.
  EXPECT_CALL(callback_, OnAcceptClients(1, 3, _));
  const auto add_status = protocol->AddClients(3);
  EXPECT_THAT(add_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 4")));
}
390
TEST_F(SimpleAggregationProtocolTest, AddClients_ProtocolNotStarted) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // Adding clients before Start must fail and must not reach the callback.
  EXPECT_CALL(callback_, OnAcceptClients).Times(0);
  const auto add_status = protocol->AddClients(1);
  EXPECT_THAT(add_status, IsCode(FAILED_PRECONDITION));
}
397
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_ProtocolNotStarted) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // A client message received before Start must be rejected.
  const auto receive_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(receive_status, IsCode(FAILED_PRECONDITION));
}
404
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_InvalidMessage) {
  auto protocol = CreateProtocolWithDefaultConfig();
  ClientMessage msg;

  // Case 1: empty message without SimpleAggregation.
  const auto empty_message_status = protocol->ReceiveClientMessage(0, msg);
  EXPECT_THAT(empty_message_status, IsCode(INVALID_ARGUMENT));

  // Case 2: SimpleAggregation is present but carries no input.
  msg.mutable_simple_aggregation();
  const auto no_input_status = protocol->ReceiveClientMessage(0, msg);
  EXPECT_THAT(no_input_status, IsCode(INVALID_ARGUMENT));

  // Case 3: the input is present but empty.
  msg.mutable_simple_aggregation()->mutable_input();
  const auto empty_input_status = protocol->ReceiveClientMessage(0, msg);
  EXPECT_THAT(empty_input_status, IsCode(INVALID_ARGUMENT));
}
420
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_InvalidClientId) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The protocol was started with one client, so client ids -1 and 2 must
  // both be rejected.
  const auto negative_id_status =
      protocol->ReceiveClientMessage(-1, MakeClientMessage());
  EXPECT_THAT(negative_id_status, IsCode(INVALID_ARGUMENT));

  const auto too_large_id_status =
      protocol->ReceiveClientMessage(2, MakeClientMessage());
  EXPECT_THAT(too_large_id_status, IsCode(INVALID_ARGUMENT));
}
431
TEST_F(SimpleAggregationProtocolTest,
       ReceiveClientMessage_DuplicateClientIdInputs) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(2), IsOk());

  // Exactly one checkpoint is expected to be parsed; it yields the scalar
  // int32 tensor "foo".
  auto mock_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
  }));
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(mock_parser))));

  // First input from client #0.
  const auto first_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(first_status, IsOk());

  // A second input for the same client must succeed without changing the
  // aggregated state.
  const auto second_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(second_status, IsOk());

  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
                  "num_clients_pending: 1 num_clients_completed: 1 "
                  "num_inputs_aggregated_and_included: 1")));
}
455
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_AfterClosingClient) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // Close client #0 before it sends any input.
  const auto close_status = protocol->CloseClient(0, absl::OkStatus());
  EXPECT_THAT(close_status, IsOk());
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
                  "num_clients_completed: 1 num_inputs_discarded: 1")));

  // A message from the closed client must succeed without changing the
  // aggregated state.
  const auto receive_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(receive_status, IsOk());
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
                  "num_clients_completed: 1 num_inputs_discarded: 1")));
}
471
TEST_F(SimpleAggregationProtocolTest,
       ReceiveClientMessage_FailureToParseInput) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The parser factory fails to create a parser for the client's checkpoint.
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(
          Return(ByMove(absl::InvalidArgumentError("Invalid checkpoint"))));

  // The client is expected to be closed with the parsing error.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));

  // Receiving the client input should still succeed.
  const auto receive_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(receive_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
}
490
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_MissingTensor) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The parsed checkpoint does not contain the tensor "foo".
  auto mock_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo")))
      .WillOnce(Return(ByMove(absl::NotFoundError("Missing tensor foo"))));
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(mock_parser))));

  // The client is expected to be closed with the NOT_FOUND error.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(NOT_FOUND)));

  // Receiving the client input should still succeed.
  const auto receive_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(receive_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
}
511
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_MismatchingTensor) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The checkpoint provides "foo" as a float tensor, whereas the default
  // configuration declares it as DT_INT32.
  auto mock_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_FLOAT, {}, CreateTestData({2.f}));
  }));
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(mock_parser))));

  // The client is expected to be closed with an INVALID_ARGUMENT error.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));

  // Receiving the client input should still succeed.
  const auto receive_status =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(receive_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
}
533
TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_UriType_Success) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The resolved checkpoint yields the scalar int32 tensor "foo".
  auto mock_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
  }));
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(mock_parser))));

  // The input of client #0 is given by URI, which must be resolved through
  // the resource resolver.
  EXPECT_CALL(resource_resolver_, RetrieveResource(0, StrEq("foo_uri")))
      .WillOnce(Return(absl::Cord{}));
  ClientMessage uri_message;
  uri_message.mutable_simple_aggregation()->mutable_input()->set_uri("foo_uri");

  // The client is expected to be closed with the OK status.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(OK)));
  const auto receive_status = protocol->ReceiveClientMessage(0, uri_message);
  EXPECT_THAT(receive_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
          "num_clients_completed: 1 num_inputs_aggregated_and_included: 1")));
}
557
TEST_F(SimpleAggregationProtocolTest,
       ReceiveClientMessage_UriType_FailToParse) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The resource resolver fails to resolve the URI sent by client #0.
  EXPECT_CALL(resource_resolver_, RetrieveResource(0, _))
      .WillOnce(Return(absl::InvalidArgumentError("Invalid uri")));
  ClientMessage uri_message;
  uri_message.mutable_simple_aggregation()->mutable_input()->set_uri("foo_uri");

  // The client is expected to be closed with the resolver's error.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));

  // Receiving the client input should still succeed.
  const auto receive_status = protocol->ReceiveClientMessage(0, uri_message);
  EXPECT_THAT(receive_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
}
577
TEST_F(SimpleAggregationProtocolTest, Complete_NoInputsReceived) {
  // The configuration declares two intrinsics:
  //   1) federated_sum over "foo": int32 tensors of shape {2, 3};
  //   2) federated_sum over "bar": scalar float tensors.
  Configuration two_sums_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {
            dim { size: 2 }
            dim { size: 3 }
          }
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {
          dim { size: 2 }
          dim { size: 3 }
        }
      }
    }
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "bar"
          dtype: DT_FLOAT
          shape {}
        }
      }
      output_tensors {
        name: "bar_out"
        dtype: DT_FLOAT
        shape {}
      }
    }
  )pb");
  auto protocol = CreateProtocol(two_sums_config);

  EXPECT_CALL(callback_, OnAcceptClients(0, 1, _));
  EXPECT_THAT(protocol->Start(1), IsOk());

  // The output checkpoint builder must be created and built.
  MockCheckpointBuilder& output_builder = ExpectCheckpointBuilder();

  // With no inputs received, both outputs must be added to the result
  // checkpoint as all-zero tensors.
  EXPECT_CALL(output_builder,
              Add(StrEq("foo_out"), IsTensor({2, 3}, {0, 0, 0, 0, 0, 0})))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(output_builder, Add(StrEq("bar_out"), IsTensor({}, {0.f})))
      .WillOnce(Return(absl::OkStatus()));

  // Before completion the only client is still pending.
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));

  // Completing must close the pending client as aborted...
  EXPECT_CALL(callback_, OnCloseClient(0, IsCode(ABORTED)));
  // ...and must invoke the OnComplete callback.
  EXPECT_CALL(callback_, OnComplete);

  const auto complete_status = protocol->Complete();
  EXPECT_THAT(complete_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_aborted: 1")));
}
649
TEST_F(SimpleAggregationProtocolTest, Complete_TwoInputsReceived) {
  // The configuration declares two intrinsics:
  //   1) federated_sum over "foo": int32 tensors of shape {2, 3};
  //   2) federated_sum over "bar": scalar float tensors.
  Configuration two_sums_config = PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {
            dim { size: 2 }
            dim { size: 3 }
          }
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {
          dim { size: 2 }
          dim { size: 3 }
        }
      }
    }
    aggregation_configs {
      intrinsic_uri: "federated_sum"
      intrinsic_args {
        input_tensor {
          name: "bar"
          dtype: DT_FLOAT
          shape {}
        }
      }
      output_tensors {
        name: "bar_out"
        dtype: DT_FLOAT
        shape {}
      }
    }
  )pb");
  auto protocol = CreateProtocol(two_sums_config);
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(2), IsOk());

  // Checkpoint returned for the first parsed input.
  auto first_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*first_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_INT32, {2, 3},
                          CreateTestData({4, 3, 11, 7, 1, 6}));
  }));
  EXPECT_CALL(*first_parser, GetTensor(StrEq("bar"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_FLOAT, {}, CreateTestData({1.f}));
  }));

  // Checkpoint returned for the second parsed input.
  auto second_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*second_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_INT32, {2, 3},
                          CreateTestData({1, 8, 2, 10, 13, 2}));
  }));
  EXPECT_CALL(*second_parser, GetTensor(StrEq("bar"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_FLOAT, {}, CreateTestData({2.f}));
  }));

  // The factory hands out the two parsers in order.
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(first_parser))))
      .WillOnce(Return(ByMove(std::move(second_parser))));

  // Both clients are expected to be closed with the OK status.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(OK)));
  EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(OK)));

  // Input from client #0: one client completed, one still pending.
  const auto first_receive =
      protocol->ReceiveClientMessage(0, MakeClientMessage());
  EXPECT_THAT(first_receive, IsOk());
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
                  "num_clients_pending: 1 num_clients_completed: 1 "
                  "num_inputs_aggregated_and_included: 1")));

  // Input from client #1: both clients completed.
  const auto second_receive =
      protocol->ReceiveClientMessage(1, MakeClientMessage());
  EXPECT_THAT(second_receive, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
          "num_clients_completed: 2 num_inputs_aggregated_and_included: 2")));

  // Completion: the output checkpoint builder must be created and built.
  MockCheckpointBuilder& output_builder = ExpectCheckpointBuilder();

  // Both outputs must be added to the result checkpoint and hold the
  // element-wise sums of the two inputs.
  EXPECT_CALL(output_builder,
              Add(StrEq("foo_out"), IsTensor({2, 3}, {5, 11, 13, 17, 14, 8})))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(output_builder, Add(StrEq("bar_out"), IsTensor({}, {3.f})))
      .WillOnce(Return(absl::OkStatus()));

  // The OnComplete callback must be invoked.
  EXPECT_CALL(callback_, OnComplete);

  const auto complete_status = protocol->Complete();
  EXPECT_THAT(complete_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
          "num_clients_completed: 2 num_inputs_aggregated_and_included: 2")));
}
755
TEST_F(SimpleAggregationProtocolTest, Complete_ProtocolNotStarted) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // Completing a protocol that was never started must fail.
  const auto complete_status = protocol->Complete();
  EXPECT_THAT(complete_status, IsCode(FAILED_PRECONDITION));
}
760
TEST_F(SimpleAggregationProtocolTest, Abort_NoInputsReceived) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients(0, 2, _));
  EXPECT_THAT(protocol->Start(2), IsOk());

  // Both pending clients are expected to be closed as aborted.
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(ABORTED)));
  EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(ABORTED)));

  const auto abort_status = protocol->Abort();
  EXPECT_THAT(abort_status, IsOk());
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_aborted: 2")));
}
773
TEST_F(SimpleAggregationProtocolTest, Abort_OneInputReceived) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients(0, 2, _));
  EXPECT_THAT(protocol->Start(2), IsOk());

  // The single parsed checkpoint yields the scalar int32 tensor "foo".
  auto mock_parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
  }));
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(mock_parser))));

  // Client #1 submits its input and is closed with the OK status.
  EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(OK)));
  const auto receive_status =
      protocol->ReceiveClientMessage(1, MakeClientMessage());
  EXPECT_THAT(receive_status, IsOk());

  // Client #0 is still pending and should be aborted on Abort().
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(ABORTED)));
  const auto abort_status = protocol->Abort();
  EXPECT_THAT(abort_status, IsOk());
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
                  "num_clients_aborted: 1 num_clients_completed:1 "
                  "num_inputs_discarded: 1")));
}
799
TEST_F(SimpleAggregationProtocolTest, Abort_ProtocolNotStarted) {
  auto protocol = CreateProtocolWithDefaultConfig();
  // Aborting a protocol that was never started must fail.
  const auto abort_status = protocol->Abort();
  EXPECT_THAT(abort_status, IsCode(FAILED_PRECONDITION));
}
804
TEST_F(SimpleAggregationProtocolTest, ConcurrentAggregation_Success) {
  const int64_t kNumClients = 10;
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients(0, kNumClients, _));
  EXPECT_THAT(protocol->Start(kNumClients), IsOk());

  // Every Create() call produces a fresh parser whose "foo" tensor is a
  // scalar int holding the next value of a shared, atomically incremented
  // counter (1, 2, 3, ...).
  std::atomic<int> next_value{0};
  EXPECT_CALL(checkpoint_parser_factory_, Create(_)).WillRepeatedly(Invoke([&] {
    auto mock_parser = std::make_unique<MockCheckpointParser>();
    EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([&] {
      return Tensor::Create(DT_INT32, {}, CreateTestData({++next_value}));
    }));
    return mock_parser;
  }));

  // Deliver one input per client from a pool of 4 threads and wait until all
  // of them have been handled.
  auto scheduler = CreateThreadPoolScheduler(4);
  for (int64_t client_id = 0; client_id < kNumClients; ++client_id) {
    scheduler->Schedule([&, client_id]() {
      EXPECT_THAT(
          protocol->ReceiveClientMessage(client_id, MakeClientMessage()),
          IsOk());
    });
  }
  scheduler->WaitUntilIdle();

  // Completion: the output checkpoint builder must be created and built.
  MockCheckpointBuilder& output_builder = ExpectCheckpointBuilder();
  // The aggregated "foo_out" must equal 1 + 2 + ... + 10 = 55.
  EXPECT_CALL(output_builder, Add(StrEq("foo_out"), IsTensor({}, {55})))
      .WillOnce(Return(absl::OkStatus()));

  // The OnComplete callback must be invoked.
  EXPECT_CALL(callback_, OnComplete);
  const auto complete_status = protocol->Complete();
  EXPECT_THAT(complete_status, IsOk());
}
844
845 // A trivial test aggregator that delegates aggregation to a function.
846 class FunctionAggregator final : public AggVectorAggregator<int> {
847 public:
848 using Func = std::function<int(int, int)>;
849
FunctionAggregator(DataType dtype,TensorShape shape,Func agg_function)850 FunctionAggregator(DataType dtype, TensorShape shape, Func agg_function)
851 : AggVectorAggregator<int>(dtype, shape), agg_function_(agg_function) {}
852
853 private:
AggregateVector(const AggVector<int> & agg_vector)854 void AggregateVector(const AggVector<int>& agg_vector) override {
855 for (auto [i, v] : agg_vector) {
856 data()[i] = agg_function_(data()[i], v);
857 }
858 }
859
860 const Func agg_function_;
861 };
862
863 // Factory for the FunctionAggregator.
864 class FunctionAggregatorFactory final : public TensorAggregatorFactory {
865 public:
FunctionAggregatorFactory(FunctionAggregator::Func agg_function)866 explicit FunctionAggregatorFactory(FunctionAggregator::Func agg_function)
867 : agg_function_(agg_function) {}
868
869 private:
Create(DataType dtype,TensorShape shape) const870 absl::StatusOr<std::unique_ptr<TensorAggregator>> Create(
871 DataType dtype, TensorShape shape) const override {
872 if (dtype != DT_INT32) {
873 return absl::InvalidArgumentError("Unsupported dtype: expected DT_INT32");
874 }
875 return std::make_unique<FunctionAggregator>(dtype, shape, agg_function_);
876 }
877
878 const FunctionAggregator::Func agg_function_;
879 };
880
// Blocks the aggregation after a few inputs so that the remaining inputs pile
// up behind it, then aborts the protocol and checks that every input ends up
// reported as discarded.
TEST_F(SimpleAggregationProtocolTest, ConcurrentAggregation_AbortWhileQueued) {
  const int64_t kTotalClients = 10;
  const int64_t kAggregationsBeforeBlocking = 3;

  // Notified by the test to let the blocked aggregation call continue.
  absl::Notification resume_aggregation;
  // Notified by the aggregation function right before it starts blocking.
  absl::Notification aggregation_blocked;
  std::atomic<int> aggregation_calls{0};

  // Sums its arguments. The first call beyond kAggregationsBeforeBlocking
  // signals `aggregation_blocked` and then waits for `resume_aggregation`;
  // later calls see the notification already set and do not block.
  FunctionAggregatorFactory agg_factory([&](int a, int b) {
    const bool past_threshold =
        ++aggregation_calls > kAggregationsBeforeBlocking;
    if (past_threshold && !aggregation_blocked.HasBeenNotified()) {
      aggregation_blocked.Notify();
      resume_aggregation.WaitForNotification();
    }
    return a + b;
  });
  // NOTE(review): the registry receives a pointer to this stack-local
  // factory; confirm that the registration is not used after the test ends.
  RegisterAggregatorFactory("foo1_aggregation", &agg_factory);

  // The intrinsic_uri in this configuration selects the custom aggregation
  // registered above.
  auto protocol = CreateProtocol(PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "foo1_aggregation"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {}
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb"));
  EXPECT_CALL(callback_, OnAcceptClients(0, kTotalClients, _));
  EXPECT_THAT(protocol->Start(kTotalClients), IsOk());

  // Every parser created by the factory serves a single scalar int tensor
  // "foo" with the value 1.
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillRepeatedly(Invoke([&] {
        auto mock_parser = std::make_unique<MockCheckpointParser>();
        EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo")))
            .WillOnce(Invoke([&] {
              return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
            }));
        return mock_parser;
      }));

  // Deliver the client messages through a pool of 10 threads, i.e. one
  // thread per client.
  auto scheduler = CreateThreadPoolScheduler(10);
  for (int64_t client_id = 0; client_id < kTotalClients; ++client_id) {
    scheduler->Schedule([&, client_id]() {
      EXPECT_THAT(
          protocol->ReceiveClientMessage(client_id, MakeClientMessage()),
          IsOk());
    });
  }

  // Wait until one aggregation call is blocked ...
  aggregation_blocked.WaitForNotification();

  // ... and then poll the status until no client is reported as pending.
  while (protocol->GetStatus().num_clients_pending() > 0) {
  }

  // At this point one input must be blocked inside the aggregation waiting
  // for the notification, 3 inputs should have already gone through the
  // aggregation, and the remaining 6 inputs should be blocked waiting to
  // enter the aggregation.

  // TODO(team): Need to revise the status implementation because it
  // treats received and pending (queued) inputs "as aggregated and pending".
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(
                  PARSE_TEXT_PROTO("num_clients_completed: 10 "
                                   "num_inputs_aggregated_and_pending: 7 "
                                   "num_inputs_aggregated_and_included: 3")));

  // Release the blocked aggregation, abort the protocol, and wait for all
  // scheduled ReceiveClientMessage calls to return. The order of these three
  // steps is kept exactly as is on purpose.
  resume_aggregation.Notify();
  EXPECT_THAT(protocol->Abort(), IsOk());
  scheduler->WaitUntilIdle();

  // All 10 inputs should now be discarded.
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_completed: 10 "
                                                  "num_inputs_discarded: 10")));
}
970
971 } // namespace
972 } // namespace fcp::aggregation
973