1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/protocol/simple_aggregation/simple_aggregation_protocol.h"
18 
19 #include <atomic>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/cord.h"
29 #include "absl/synchronization/notification.h"
30 #include "fcp/aggregation/core/agg_vector_aggregator.h"
31 #include "fcp/aggregation/core/tensor.h"
32 #include "fcp/aggregation/core/tensor_aggregator_factory.h"
33 #include "fcp/aggregation/core/tensor_aggregator_registry.h"
34 #include "fcp/aggregation/protocol/aggregation_protocol_messages.pb.h"
35 #include "fcp/aggregation/protocol/checkpoint_builder.h"
36 #include "fcp/aggregation/protocol/checkpoint_parser.h"
37 #include "fcp/aggregation/protocol/configuration.pb.h"
38 #include "fcp/aggregation/protocol/testing/test_callback.h"
39 #include "fcp/aggregation/testing/test_data.h"
40 #include "fcp/aggregation/testing/testing.h"
41 #include "fcp/base/monitoring.h"
42 #include "fcp/base/scheduler.h"
43 #include "fcp/testing/testing.h"
44 
45 namespace fcp::aggregation {
46 namespace {
47 
48 using ::testing::_;
49 using ::testing::ByMove;
50 using ::testing::Eq;
51 using ::testing::Invoke;
52 using ::testing::Return;
53 using ::testing::StrEq;
54 
55 // TODO(team): Consider moving mock classes into a separate test library.
56 class MockCheckpointParser : public CheckpointParser {
57  public:
58   MOCK_METHOD(absl::StatusOr<Tensor>, GetTensor, (const std::string& name),
59               (const override));
60 };
61 
62 class MockCheckpointParserFactory : public CheckpointParserFactory {
63  public:
64   MOCK_METHOD(absl::StatusOr<std::unique_ptr<CheckpointParser>>, Create,
65               (const absl::Cord& serialized_checkpoint), (const override));
66 };
67 
68 class MockCheckpointBuilder : public CheckpointBuilder {
69  public:
70   MOCK_METHOD(absl::Status, Add,
71               (const std::string& name, const Tensor& tensor), (override));
72   MOCK_METHOD(absl::StatusOr<absl::Cord>, Build, (), (override));
73 };
74 
75 class MockCheckpointBuilderFactory : public CheckpointBuilderFactory {
76  public:
77   MOCK_METHOD(std::unique_ptr<CheckpointBuilder>, Create, (), (const override));
78 };
79 
80 class MockResourceResolver : public ResourceResolver {
81  public:
82   MOCK_METHOD(absl::StatusOr<absl::Cord>, RetrieveResource,
83               (int64_t client_id, const std::string& uri), (override));
84 };
85 
86 class SimpleAggregationProtocolTest : public ::testing::Test {
87  protected:
88   // Returns default configuration.
89   Configuration default_configuration() const;
90 
91   // Returns the default instance of checkpoint bilder;
ExpectCheckpointBuilder()92   MockCheckpointBuilder& ExpectCheckpointBuilder() {
93     MockCheckpointBuilder& checkpoint_builder = *wrapped_checkpoint_builder_;
94     EXPECT_CALL(checkpoint_builder_factory_, Create())
95         .WillOnce(Return(ByMove(std::move(wrapped_checkpoint_builder_))));
96     EXPECT_CALL(checkpoint_builder, Build()).WillOnce(Return(absl::Cord{}));
97     return checkpoint_builder;
98   }
99 
100   // Creates an instance of SimpleAggregationProtocol with the specified config.
101   std::unique_ptr<SimpleAggregationProtocol> CreateProtocol(
102       Configuration config);
103 
104   // Creates an instance of SimpleAggregationProtocol with the default config.
CreateProtocolWithDefaultConfig()105   std::unique_ptr<SimpleAggregationProtocol> CreateProtocolWithDefaultConfig() {
106     return CreateProtocol(default_configuration());
107   }
108 
109   MockAggregationProtocolCallback callback_;
110 
111   MockCheckpointParserFactory checkpoint_parser_factory_;
112   MockCheckpointBuilderFactory checkpoint_builder_factory_;
113   MockResourceResolver resource_resolver_;
114 
115  private:
116   std::unique_ptr<MockCheckpointBuilder> wrapped_checkpoint_builder_ =
117       std::make_unique<MockCheckpointBuilder>();
118 };
119 
default_configuration() const120 Configuration SimpleAggregationProtocolTest::default_configuration() const {
121   // One "federated_sum" intrinsic with a single scalar int32 tensor.
122   return PARSE_TEXT_PROTO(R"pb(
123     aggregation_configs {
124       intrinsic_uri: "federated_sum"
125       intrinsic_args {
126         input_tensor {
127           name: "foo"
128           dtype: DT_INT32
129           shape {}
130         }
131       }
132       output_tensors {
133         name: "foo_out"
134         dtype: DT_INT32
135         shape {}
136       }
137     }
138   )pb");
139 }
140 
141 std::unique_ptr<SimpleAggregationProtocol>
CreateProtocol(Configuration config)142 SimpleAggregationProtocolTest::CreateProtocol(Configuration config) {
143   // Verify that the protocol can be created successfully.
144   absl::StatusOr<std::unique_ptr<SimpleAggregationProtocol>>
145       protocol_or_status = SimpleAggregationProtocol::Create(
146           config, &callback_, &checkpoint_parser_factory_,
147           &checkpoint_builder_factory_, &resource_resolver_);
148   EXPECT_THAT(protocol_or_status, IsOk());
149   return std::move(protocol_or_status).value();
150 }
151 
MakeClientMessage()152 ClientMessage MakeClientMessage() {
153   ClientMessage message;
154   message.mutable_simple_aggregation()->mutable_input()->set_inline_bytes("");
155   return message;
156 }
157 
TEST_F(SimpleAggregationProtocolTest,Create_UnsupportedNumberOfInputs)158 TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedNumberOfInputs) {
159   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
160     aggregation_configs {
161       intrinsic_uri: "federated_sum"
162       intrinsic_args {
163         input_tensor {
164           name: "foo"
165           dtype: DT_INT32
166           shape {}
167         }
168       }
169       intrinsic_args {
170         input_tensor {
171           name: "bar"
172           dtype: DT_INT32
173           shape {}
174         }
175       }
176       output_tensors {
177         name: "foo_out"
178         dtype: DT_INT32
179         shape {}
180       }
181     }
182   )pb");
183   EXPECT_THAT(SimpleAggregationProtocol::Create(
184                   config_message, &callback_, &checkpoint_parser_factory_,
185                   &checkpoint_builder_factory_, &resource_resolver_),
186               IsCode(INVALID_ARGUMENT));
187 }
188 
TEST_F(SimpleAggregationProtocolTest,Create_UnsupportedNumberOfOutputs)189 TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedNumberOfOutputs) {
190   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
191     aggregation_configs {
192       intrinsic_uri: "federated_sum"
193       intrinsic_args {
194         input_tensor {
195           name: "foo"
196           dtype: DT_INT32
197           shape {}
198         }
199       }
200       output_tensors {
201         name: "foo_out"
202         dtype: DT_INT32
203         shape {}
204       }
205       output_tensors {
206         name: "bar_out"
207         dtype: DT_INT32
208         shape {}
209       }
210     }
211   )pb");
212   EXPECT_THAT(SimpleAggregationProtocol::Create(
213                   config_message, &callback_, &checkpoint_parser_factory_,
214                   &checkpoint_builder_factory_, &resource_resolver_),
215               IsCode(INVALID_ARGUMENT));
216 }
217 
TEST_F(SimpleAggregationProtocolTest,Create_UnsupportedInputType)218 TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedInputType) {
219   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
220     aggregation_configs {
221       intrinsic_uri: "federated_sum"
222       intrinsic_args { parameter {} }
223       output_tensors {
224         name: "foo_out"
225         dtype: DT_INT32
226         shape {}
227       }
228     }
229   )pb");
230   EXPECT_THAT(SimpleAggregationProtocol::Create(
231                   config_message, &callback_, &checkpoint_parser_factory_,
232                   &checkpoint_builder_factory_, &resource_resolver_),
233               IsCode(INVALID_ARGUMENT));
234 }
235 
TEST_F(SimpleAggregationProtocolTest,Create_UnsupportedIntrinsicUri)236 TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedIntrinsicUri) {
237   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
238     aggregation_configs {
239       intrinsic_uri: "unsupported_xyz"
240       intrinsic_args {
241         input_tensor {
242           name: "foo"
243           dtype: DT_INT32
244           shape {}
245         }
246       }
247       output_tensors {
248         name: "foo_out"
249         dtype: DT_INT32
250         shape {}
251       }
252     }
253   )pb");
254   EXPECT_THAT(SimpleAggregationProtocol::Create(
255                   config_message, &callback_, &checkpoint_parser_factory_,
256                   &checkpoint_builder_factory_, &resource_resolver_),
257               IsCode(INVALID_ARGUMENT));
258 }
259 
TEST_F(SimpleAggregationProtocolTest,Create_UnsupportedInputSpec)260 TEST_F(SimpleAggregationProtocolTest, Create_UnsupportedInputSpec) {
261   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
262     aggregation_configs {
263       intrinsic_uri: "federated_sum"
264       intrinsic_args {
265         input_tensor {
266           name: "foo"
267           dtype: DT_INT32
268           shape { dim { size: -1 } }
269         }
270       }
271       output_tensors {
272         name: "foo_out"
273         dtype: DT_INT32
274         shape {}
275       }
276     }
277   )pb");
278   EXPECT_THAT(SimpleAggregationProtocol::Create(
279                   config_message, &callback_, &checkpoint_parser_factory_,
280                   &checkpoint_builder_factory_, &resource_resolver_),
281               IsCode(INVALID_ARGUMENT));
282 }
283 
TEST_F(SimpleAggregationProtocolTest,Create_UnmatchingInputAndOutputDataType)284 TEST_F(SimpleAggregationProtocolTest, Create_UnmatchingInputAndOutputDataType) {
285   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
286     aggregation_configs {
287       intrinsic_uri: "federated_sum"
288       intrinsic_args {
289         input_tensor {
290           name: "foo"
291           dtype: DT_INT32
292           shape {}
293         }
294       }
295       output_tensors {
296         name: "foo_out"
297         dtype: DT_FLOAT
298         shape {}
299       }
300     }
301   )pb");
302   EXPECT_THAT(SimpleAggregationProtocol::Create(
303                   config_message, &callback_, &checkpoint_parser_factory_,
304                   &checkpoint_builder_factory_, &resource_resolver_),
305               IsCode(INVALID_ARGUMENT));
306 }
307 
TEST_F(SimpleAggregationProtocolTest,Create_UnmatchingInputAndOutputShape)308 TEST_F(SimpleAggregationProtocolTest, Create_UnmatchingInputAndOutputShape) {
309   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
310     aggregation_configs {
311       intrinsic_uri: "federated_sum"
312       intrinsic_args {
313         input_tensor {
314           name: "foo"
315           dtype: DT_INT32
316           shape { dim { size: 1 } }
317         }
318       }
319       output_tensors {
320         name: "foo_out"
321         dtype: DT_INT32
322         shape { dim { size: 2 } }
323       }
324     }
325   )pb");
326   EXPECT_THAT(SimpleAggregationProtocol::Create(
327                   config_message, &callback_, &checkpoint_parser_factory_,
328                   &checkpoint_builder_factory_, &resource_resolver_),
329               IsCode(INVALID_ARGUMENT));
330 }
331 
TEST_F(SimpleAggregationProtocolTest,StartProtocol_Success)332 TEST_F(SimpleAggregationProtocolTest, StartProtocol_Success) {
333   auto protocol = CreateProtocolWithDefaultConfig();
334   EXPECT_THAT(protocol->Start(3), IsOk());
335   EXPECT_THAT(
336       protocol->GetStatus(),
337       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 3")));
338 }
339 
340 // TODO(team): Add similar tests for other callbacks.
TEST_F(SimpleAggregationProtocolTest,
       StartProtocol_AcceptClientsProtocolReentrace) {
  // The OnAcceptClients callback calls GetStatus() on the protocol, checking
  // that the protocol can be re-entered from inside a callback.
  std::unique_ptr<SimpleAggregationProtocol> agg_protocol =
      CreateProtocolWithDefaultConfig();

  auto check_status_from_callback = [&]() {
    EXPECT_THAT(
        agg_protocol->GetStatus(),
        EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
  };
  EXPECT_CALL(callback_, OnAcceptClients(0, 1, _))
      .WillOnce(Invoke(check_status_from_callback));

  EXPECT_THAT(agg_protocol->Start(1), IsOk());
}
354 
TEST_F(SimpleAggregationProtocolTest,StartProtocol_MultipleCalls)355 TEST_F(SimpleAggregationProtocolTest, StartProtocol_MultipleCalls) {
356   auto protocol = CreateProtocolWithDefaultConfig();
357   EXPECT_CALL(callback_, OnAcceptClients).Times(1);
358   EXPECT_THAT(protocol->Start(1), IsOk());
359   // The second Start call must fail.
360   EXPECT_THAT(protocol->Start(1), IsCode(FAILED_PRECONDITION));
361 }
362 
TEST_F(SimpleAggregationProtocolTest,StartProtocol_ZeroClients)363 TEST_F(SimpleAggregationProtocolTest, StartProtocol_ZeroClients) {
364   auto protocol = CreateProtocolWithDefaultConfig();
365   EXPECT_CALL(callback_, OnAcceptClients).Times(0);
366   EXPECT_THAT(protocol->Start(0), IsOk());
367 }
368 
TEST_F(SimpleAggregationProtocolTest,StartProtocol_NegativeClients)369 TEST_F(SimpleAggregationProtocolTest, StartProtocol_NegativeClients) {
370   auto protocol = CreateProtocolWithDefaultConfig();
371   EXPECT_CALL(callback_, OnAcceptClients).Times(0);
372   EXPECT_THAT(protocol->Start(-1), IsCode(INVALID_ARGUMENT));
373 }
374 
TEST_F(SimpleAggregationProtocolTest,AddClients_Success)375 TEST_F(SimpleAggregationProtocolTest, AddClients_Success) {
376   auto protocol = CreateProtocolWithDefaultConfig();
377 
378   EXPECT_CALL(callback_, OnAcceptClients(0, 1, _));
379   EXPECT_THAT(protocol->Start(1), IsOk());
380   EXPECT_THAT(
381       protocol->GetStatus(),
382       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
383 
384   EXPECT_CALL(callback_, OnAcceptClients(1, 3, _));
385   EXPECT_THAT(protocol->AddClients(3), IsOk());
386   EXPECT_THAT(
387       protocol->GetStatus(),
388       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 4")));
389 }
390 
TEST_F(SimpleAggregationProtocolTest,AddClients_ProtocolNotStarted)391 TEST_F(SimpleAggregationProtocolTest, AddClients_ProtocolNotStarted) {
392   auto protocol = CreateProtocolWithDefaultConfig();
393   // Must fail because the protocol isn't started.
394   EXPECT_CALL(callback_, OnAcceptClients).Times(0);
395   EXPECT_THAT(protocol->AddClients(1), IsCode(FAILED_PRECONDITION));
396 }
397 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_ProtocolNotStarted)398 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_ProtocolNotStarted) {
399   auto protocol = CreateProtocolWithDefaultConfig();
400   // Must fail because the protocol isn't started.
401   EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()),
402               IsCode(FAILED_PRECONDITION));
403 }
404 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_InvalidMessage)405 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_InvalidMessage) {
406   auto protocol = CreateProtocolWithDefaultConfig();
407   ClientMessage message;
408   // Empty message without SimpleAggregation.
409   EXPECT_THAT(protocol->ReceiveClientMessage(0, message),
410               IsCode(INVALID_ARGUMENT));
411   // Message with SimpleAggregation but without the input.
412   message.mutable_simple_aggregation();
413   EXPECT_THAT(protocol->ReceiveClientMessage(0, message),
414               IsCode(INVALID_ARGUMENT));
415   // Message with empty input.
416   message.mutable_simple_aggregation()->mutable_input();
417   EXPECT_THAT(protocol->ReceiveClientMessage(0, message),
418               IsCode(INVALID_ARGUMENT));
419 }
420 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_InvalidClientId)421 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_InvalidClientId) {
422   auto protocol = CreateProtocolWithDefaultConfig();
423   EXPECT_CALL(callback_, OnAcceptClients);
424   EXPECT_THAT(protocol->Start(1), IsOk());
425   // Must fail for the client_id -1 and 2.
426   EXPECT_THAT(protocol->ReceiveClientMessage(-1, MakeClientMessage()),
427               IsCode(INVALID_ARGUMENT));
428   EXPECT_THAT(protocol->ReceiveClientMessage(2, MakeClientMessage()),
429               IsCode(INVALID_ARGUMENT));
430 }
431 
TEST_F(SimpleAggregationProtocolTest,
       ReceiveClientMessage_DuplicateClientIdInputs) {
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(protocol->Start(2), IsOk());

  // The parser factory, and the parser it returns, are each expected to be
  // used exactly once: for the first input received from client #0.
  auto parser = std::make_unique<MockCheckpointParser>();
  EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
    return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
  }));

  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(Return(ByMove(std::move(parser))));

  EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
  // A second input for the same client must succeed without changing the
  // aggregated state.
  EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
                  "num_clients_pending: 1 num_clients_completed: 1 "
                  "num_inputs_aggregated_and_included: 1")));
}
455 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_AfterClosingClient)456 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_AfterClosingClient) {
457   auto protocol = CreateProtocolWithDefaultConfig();
458   EXPECT_CALL(callback_, OnAcceptClients);
459   EXPECT_THAT(protocol->Start(1), IsOk());
460 
461   EXPECT_THAT(protocol->CloseClient(0, absl::OkStatus()), IsOk());
462   EXPECT_THAT(protocol->GetStatus(),
463               EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
464                   "num_clients_completed: 1 num_inputs_discarded: 1")));
465   // This must succeed to without changing the aggregated state.
466   EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
467   EXPECT_THAT(protocol->GetStatus(),
468               EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
469                   "num_clients_completed: 1 num_inputs_discarded: 1")));
470 }
471 
TEST_F(SimpleAggregationProtocolTest,
       ReceiveClientMessage_FailureToParseInput) {
  std::unique_ptr<SimpleAggregationProtocol> agg_protocol =
      CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(agg_protocol->Start(1), IsOk());

  // The parser factory rejects the checkpoint outright...
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillOnce(
          Return(ByMove(absl::InvalidArgumentError("Invalid checkpoint"))));

  // ...so client #0 is closed with that error code...
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));

  // ...while ReceiveClientMessage itself still reports success.
  const ClientMessage client_message = MakeClientMessage();
  EXPECT_THAT(agg_protocol->ReceiveClientMessage(0, client_message), IsOk());
  EXPECT_THAT(
      agg_protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
}
490 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_MissingTensor)491 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_MissingTensor) {
492   auto protocol = CreateProtocolWithDefaultConfig();
493   EXPECT_CALL(callback_, OnAcceptClients);
494   EXPECT_THAT(protocol->Start(1), IsOk());
495 
496   auto parser = std::make_unique<MockCheckpointParser>();
497   EXPECT_CALL(*parser, GetTensor(StrEq("foo")))
498       .WillOnce(Return(ByMove(absl::NotFoundError("Missing tensor foo"))));
499 
500   EXPECT_CALL(checkpoint_parser_factory_, Create(_))
501       .WillOnce(Return(ByMove(std::move(parser))));
502 
503   EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(NOT_FOUND)));
504 
505   // Receiving the client input should still succeed.
506   EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
507   EXPECT_THAT(
508       protocol->GetStatus(),
509       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
510 }
511 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_MismatchingTensor)512 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_MismatchingTensor) {
513   auto protocol = CreateProtocolWithDefaultConfig();
514   EXPECT_CALL(callback_, OnAcceptClients);
515   EXPECT_THAT(protocol->Start(1), IsOk());
516 
517   auto parser = std::make_unique<MockCheckpointParser>();
518   EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
519     return Tensor::Create(DT_FLOAT, {}, CreateTestData({2.f}));
520   }));
521 
522   EXPECT_CALL(checkpoint_parser_factory_, Create(_))
523       .WillOnce(Return(ByMove(std::move(parser))));
524 
525   EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));
526 
527   // Receiving the client input should still succeed.
528   EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
529   EXPECT_THAT(
530       protocol->GetStatus(),
531       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
532 }
533 
TEST_F(SimpleAggregationProtocolTest,ReceiveClientMessage_UriType_Success)534 TEST_F(SimpleAggregationProtocolTest, ReceiveClientMessage_UriType_Success) {
535   auto protocol = CreateProtocolWithDefaultConfig();
536   EXPECT_CALL(callback_, OnAcceptClients);
537   EXPECT_THAT(protocol->Start(1), IsOk());
538   auto parser = std::make_unique<MockCheckpointParser>();
539   EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
540     return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
541   }));
542   EXPECT_CALL(checkpoint_parser_factory_, Create(_))
543       .WillOnce(Return(ByMove(std::move(parser))));
544 
545   // Receive input for the client #0
546   EXPECT_CALL(resource_resolver_, RetrieveResource(0, StrEq("foo_uri")))
547       .WillOnce(Return(absl::Cord{}));
548   ClientMessage message;
549   message.mutable_simple_aggregation()->mutable_input()->set_uri("foo_uri");
550   EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(OK)));
551   EXPECT_THAT(protocol->ReceiveClientMessage(0, message), IsOk());
552   EXPECT_THAT(
553       protocol->GetStatus(),
554       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
555           "num_clients_completed: 1 num_inputs_aggregated_and_included: 1")));
556 }
557 
TEST_F(SimpleAggregationProtocolTest,
       ReceiveClientMessage_UriType_FailToParse) {
  std::unique_ptr<SimpleAggregationProtocol> agg_protocol =
      CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients);
  EXPECT_THAT(agg_protocol->Start(1), IsOk());

  // The resolver cannot retrieve the URI that client #0 points at.
  EXPECT_CALL(resource_resolver_, RetrieveResource(0, _))
      .WillOnce(Return(absl::InvalidArgumentError("Invalid uri")));
  ClientMessage uri_message;
  uri_message.mutable_simple_aggregation()->mutable_input()->set_uri("foo_uri");
  EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(INVALID_ARGUMENT)));

  // The client fails, but ReceiveClientMessage itself still reports success.
  EXPECT_THAT(agg_protocol->ReceiveClientMessage(0, uri_message), IsOk());
  EXPECT_THAT(
      agg_protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_failed: 1")));
}
577 
TEST_F(SimpleAggregationProtocolTest,Complete_NoInputsReceived)578 TEST_F(SimpleAggregationProtocolTest, Complete_NoInputsReceived) {
579   // Two intrinsics:
580   // 1) federated_sum "foo" that takes int32 {2,3} tensors.
581   // 2) federated_sum "bar" that takes scalar float tensors.
582   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
583     aggregation_configs {
584       intrinsic_uri: "federated_sum"
585       intrinsic_args {
586         input_tensor {
587           name: "foo"
588           dtype: DT_INT32
589           shape {
590             dim { size: 2 }
591             dim { size: 3 }
592           }
593         }
594       }
595       output_tensors {
596         name: "foo_out"
597         dtype: DT_INT32
598         shape {
599           dim { size: 2 }
600           dim { size: 3 }
601         }
602       }
603     }
604     aggregation_configs {
605       intrinsic_uri: "federated_sum"
606       intrinsic_args {
607         input_tensor {
608           name: "bar"
609           dtype: DT_FLOAT
610           shape {}
611         }
612       }
613       output_tensors {
614         name: "bar_out"
615         dtype: DT_FLOAT
616         shape {}
617       }
618     }
619   )pb");
620   auto protocol = CreateProtocol(config_message);
621 
622   EXPECT_CALL(callback_, OnAcceptClients(0, 1, _));
623   EXPECT_THAT(protocol->Start(1), IsOk());
624 
625   // Verify that the checkpoint builder is created.
626   auto& checkpoint_builder = ExpectCheckpointBuilder();
627 
628   // Verify that foo_out and bar_out tensors are added to the result checkpoint
629   EXPECT_CALL(checkpoint_builder,
630               Add(StrEq("foo_out"), IsTensor({2, 3}, {0, 0, 0, 0, 0, 0})))
631       .WillOnce(Return(absl::OkStatus()));
632   EXPECT_CALL(checkpoint_builder, Add(StrEq("bar_out"), IsTensor({}, {0.f})))
633       .WillOnce(Return(absl::OkStatus()));
634 
635   EXPECT_THAT(
636       protocol->GetStatus(),
637       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_pending: 1")));
638 
639   // Verify that the pending client is closed.
640   EXPECT_CALL(callback_, OnCloseClient(0, IsCode(ABORTED)));
641   // Verify that the Complete callback method is called.
642   EXPECT_CALL(callback_, OnComplete);
643 
644   EXPECT_THAT(protocol->Complete(), IsOk());
645   EXPECT_THAT(
646       protocol->GetStatus(),
647       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_aborted: 1")));
648 }
649 
TEST_F(SimpleAggregationProtocolTest,Complete_TwoInputsReceived)650 TEST_F(SimpleAggregationProtocolTest, Complete_TwoInputsReceived) {
651   // Two intrinsics:
652   // 1) federated_sum "foo" that takes int32 {2,3} tensors.
653   // 2) federated_sum "bar" that takes scalar float tensors.
654   Configuration config_message = PARSE_TEXT_PROTO(R"pb(
655     aggregation_configs {
656       intrinsic_uri: "federated_sum"
657       intrinsic_args {
658         input_tensor {
659           name: "foo"
660           dtype: DT_INT32
661           shape {
662             dim { size: 2 }
663             dim { size: 3 }
664           }
665         }
666       }
667       output_tensors {
668         name: "foo_out"
669         dtype: DT_INT32
670         shape {
671           dim { size: 2 }
672           dim { size: 3 }
673         }
674       }
675     }
676     aggregation_configs {
677       intrinsic_uri: "federated_sum"
678       intrinsic_args {
679         input_tensor {
680           name: "bar"
681           dtype: DT_FLOAT
682           shape {}
683         }
684       }
685       output_tensors {
686         name: "bar_out"
687         dtype: DT_FLOAT
688         shape {}
689       }
690     }
691   )pb");
692   auto protocol = CreateProtocol(config_message);
693   EXPECT_CALL(callback_, OnAcceptClients);
694   EXPECT_THAT(protocol->Start(2), IsOk());
695 
696   // Expect two inputs.
697   auto parser1 = std::make_unique<MockCheckpointParser>();
698   EXPECT_CALL(*parser1, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
699     return Tensor::Create(DT_INT32, {2, 3},
700                           CreateTestData({4, 3, 11, 7, 1, 6}));
701   }));
702   EXPECT_CALL(*parser1, GetTensor(StrEq("bar"))).WillOnce(Invoke([] {
703     return Tensor::Create(DT_FLOAT, {}, CreateTestData({1.f}));
704   }));
705 
706   auto parser2 = std::make_unique<MockCheckpointParser>();
707   EXPECT_CALL(*parser2, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
708     return Tensor::Create(DT_INT32, {2, 3},
709                           CreateTestData({1, 8, 2, 10, 13, 2}));
710   }));
711   EXPECT_CALL(*parser2, GetTensor(StrEq("bar"))).WillOnce(Invoke([] {
712     return Tensor::Create(DT_FLOAT, {}, CreateTestData({2.f}));
713   }));
714 
715   EXPECT_CALL(checkpoint_parser_factory_, Create(_))
716       .WillOnce(Return(ByMove(std::move(parser1))))
717       .WillOnce(Return(ByMove(std::move(parser2))));
718 
719   EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(OK)));
720   EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(OK)));
721 
722   // Handle the inputs.
723   EXPECT_THAT(protocol->ReceiveClientMessage(0, MakeClientMessage()), IsOk());
724   EXPECT_THAT(protocol->GetStatus(),
725               EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
726                   "num_clients_pending: 1 num_clients_completed: 1 "
727                   "num_inputs_aggregated_and_included: 1")));
728 
729   EXPECT_THAT(protocol->ReceiveClientMessage(1, MakeClientMessage()), IsOk());
730   EXPECT_THAT(
731       protocol->GetStatus(),
732       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
733           "num_clients_completed: 2 num_inputs_aggregated_and_included: 2")));
734 
735   // Complete the protocol.
736   // Verify that the checkpoint builder is created.
737   auto& checkpoint_builder = ExpectCheckpointBuilder();
738 
739   // Verify that foo_out and bar_out tensors are added to the result checkpoint
740   EXPECT_CALL(checkpoint_builder,
741               Add(StrEq("foo_out"), IsTensor({2, 3}, {5, 11, 13, 17, 14, 8})))
742       .WillOnce(Return(absl::OkStatus()));
743   EXPECT_CALL(checkpoint_builder, Add(StrEq("bar_out"), IsTensor({}, {3.f})))
744       .WillOnce(Return(absl::OkStatus()));
745 
746   // Verify that the OnComplete callback method is called.
747   EXPECT_CALL(callback_, OnComplete);
748 
749   EXPECT_THAT(protocol->Complete(), IsOk());
750   EXPECT_THAT(
751       protocol->GetStatus(),
752       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
753           "num_clients_completed: 2 num_inputs_aggregated_and_included: 2")));
754 }
755 
TEST_F(SimpleAggregationProtocolTest,Complete_ProtocolNotStarted)756 TEST_F(SimpleAggregationProtocolTest, Complete_ProtocolNotStarted) {
757   auto protocol = CreateProtocolWithDefaultConfig();
758   EXPECT_THAT(protocol->Complete(), IsCode(FAILED_PRECONDITION));
759 }
760 
TEST_F(SimpleAggregationProtocolTest,Abort_NoInputsReceived)761 TEST_F(SimpleAggregationProtocolTest, Abort_NoInputsReceived) {
762   auto protocol = CreateProtocolWithDefaultConfig();
763   EXPECT_CALL(callback_, OnAcceptClients(0, 2, _));
764   EXPECT_THAT(protocol->Start(2), IsOk());
765 
766   EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(ABORTED)));
767   EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(ABORTED)));
768   EXPECT_THAT(protocol->Abort(), IsOk());
769   EXPECT_THAT(
770       protocol->GetStatus(),
771       EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_aborted: 2")));
772 }
773 
TEST_F(SimpleAggregationProtocolTest,Abort_OneInputReceived)774 TEST_F(SimpleAggregationProtocolTest, Abort_OneInputReceived) {
775   auto protocol = CreateProtocolWithDefaultConfig();
776   EXPECT_CALL(callback_, OnAcceptClients(0, 2, _));
777   EXPECT_THAT(protocol->Start(2), IsOk());
778 
779   auto parser = std::make_unique<MockCheckpointParser>();
780   EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([] {
781     return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
782   }));
783 
784   EXPECT_CALL(checkpoint_parser_factory_, Create(_))
785       .WillOnce(Return(ByMove(std::move(parser))));
786 
787   // Receive input for the client #1
788   EXPECT_CALL(callback_, OnCloseClient(Eq(1), IsCode(OK)));
789   EXPECT_THAT(protocol->ReceiveClientMessage(1, MakeClientMessage()), IsOk());
790 
791   // The client #0 should be aborted on Abort().
792   EXPECT_CALL(callback_, OnCloseClient(Eq(0), IsCode(ABORTED)));
793   EXPECT_THAT(protocol->Abort(), IsOk());
794   EXPECT_THAT(protocol->GetStatus(),
795               EqualsProto<StatusMessage>(PARSE_TEXT_PROTO(
796                   "num_clients_aborted: 1 num_clients_completed:1 "
797                   "num_inputs_discarded: 1")));
798 }
799 
TEST_F(SimpleAggregationProtocolTest, Abort_ProtocolNotStarted) {
  // Start() is never called on this protocol, so Abort() is expected to be
  // rejected with FAILED_PRECONDITION.
  auto unstarted_protocol = CreateProtocolWithDefaultConfig();
  auto abort_status = unstarted_protocol->Abort();
  EXPECT_THAT(abort_status, IsCode(FAILED_PRECONDITION));
}
804 
TEST_F(SimpleAggregationProtocolTest, ConcurrentAggregation_Success) {
  // Ten clients each submit one scalar "foo" input. The inputs are received
  // concurrently, and their sum must end up in "foo_out".
  const int64_t kClientCount = 10;
  auto protocol = CreateProtocolWithDefaultConfig();
  EXPECT_CALL(callback_, OnAcceptClients(0, kClientCount, _));
  EXPECT_THAT(protocol->Start(kClientCount), IsOk());

  // Every parser created by the factory yields a scalar DT_INT32 tensor whose
  // value is drawn from a shared counter. The protocol therefore sees the
  // values 1, 2, ..., kClientCount, in whatever order the threads run.
  std::atomic<int> last_value = 0;
  auto make_parser = [&] {
    auto mock_parser = std::make_unique<MockCheckpointParser>();
    EXPECT_CALL(*mock_parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([&] {
      return Tensor::Create(DT_INT32, {}, CreateTestData({++last_value}));
    }));
    return mock_parser;
  };
  EXPECT_CALL(checkpoint_parser_factory_, Create(_))
      .WillRepeatedly(Invoke(make_parser));

  // Deliver all client messages from a pool of 4 worker threads, then wait
  // for every scheduled delivery to finish.
  auto thread_pool = CreateThreadPoolScheduler(4);
  for (int64_t client_id = 0; client_id < kClientCount; ++client_id) {
    thread_pool->Schedule([&, client_id]() {
      EXPECT_THAT(
          protocol->ReceiveClientMessage(client_id, MakeClientMessage()),
          IsOk());
    });
  }
  thread_pool->WaitUntilIdle();

  // Completing the protocol creates a checkpoint builder that receives
  // foo_out = 1 + 2 + ... + 10 = 55. Completion is also reported through the
  // callback.
  auto& checkpoint_builder = ExpectCheckpointBuilder();
  EXPECT_CALL(checkpoint_builder, Add(StrEq("foo_out"), IsTensor({}, {55})))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(callback_, OnComplete);
  EXPECT_THAT(protocol->Complete(), IsOk());
}
844 
845 // A trivial test aggregator that delegates aggregation to a function.
846 class FunctionAggregator final : public AggVectorAggregator<int> {
847  public:
848   using Func = std::function<int(int, int)>;
849 
FunctionAggregator(DataType dtype,TensorShape shape,Func agg_function)850   FunctionAggregator(DataType dtype, TensorShape shape, Func agg_function)
851       : AggVectorAggregator<int>(dtype, shape), agg_function_(agg_function) {}
852 
853  private:
AggregateVector(const AggVector<int> & agg_vector)854   void AggregateVector(const AggVector<int>& agg_vector) override {
855     for (auto [i, v] : agg_vector) {
856       data()[i] = agg_function_(data()[i], v);
857     }
858   }
859 
860   const Func agg_function_;
861 };
862 
863 // Factory for the FunctionAggregator.
864 class FunctionAggregatorFactory final : public TensorAggregatorFactory {
865  public:
FunctionAggregatorFactory(FunctionAggregator::Func agg_function)866   explicit FunctionAggregatorFactory(FunctionAggregator::Func agg_function)
867       : agg_function_(agg_function) {}
868 
869  private:
Create(DataType dtype,TensorShape shape) const870   absl::StatusOr<std::unique_ptr<TensorAggregator>> Create(
871       DataType dtype, TensorShape shape) const override {
872     if (dtype != DT_INT32) {
873       return absl::InvalidArgumentError("Unsupported dtype: expected DT_INT32");
874     }
875     return std::make_unique<FunctionAggregator>(dtype, shape, agg_function_);
876   }
877 
878   const FunctionAggregator::Func agg_function_;
879 };
880 
TEST_F(SimpleAggregationProtocolTest, ConcurrentAggregation_AbortWhileQueued) {
  const int64_t kNumClients = 10;
  // Number of aggregation calls that complete before the aggregation function
  // blocks; the function blocks on call number kNumClientBeforeBlocking + 1.
  const int64_t kNumClientBeforeBlocking = 3;

  // Notified by the test to let the blocked aggregation call continue.
  absl::Notification resume_aggregation_notification;
  // Notified by the aggregation function once it starts blocking.
  absl::Notification aggregation_blocked_notification;
  // Incremented on every call to the aggregation function below.
  std::atomic<int> agg_counter = 0;
  // Sums its arguments. The first call that pushes agg_counter past
  // kNumClientBeforeBlocking signals aggregation_blocked_notification and then
  // waits for resume_aggregation_notification before returning.
  // NOTE(review): absl::Notification::Notify() may be called at most once.
  // The HasBeenNotified() guard is only race-free if the protocol serializes
  // calls into the aggregator — confirm.
  FunctionAggregatorFactory agg_factory([&](int a, int b) {
    if (++agg_counter > kNumClientBeforeBlocking &&
        !aggregation_blocked_notification.HasBeenNotified()) {
      aggregation_blocked_notification.Notify();
      resume_aggregation_notification.WaitForNotification();
    }
    return a + b;
  });
  // NOTE(review): agg_factory is a stack object registered by address. If the
  // registry retains the pointer after this test returns, or rejects duplicate
  // names (e.g. under --gtest_repeat), this may dangle or fail — confirm
  // against RegisterAggregatorFactory.
  RegisterAggregatorFactory("foo1_aggregation", &agg_factory);

  // The configuration below refers to the custom aggregation registered
  // above.
  auto protocol = CreateProtocol(PARSE_TEXT_PROTO(R"pb(
    aggregation_configs {
      intrinsic_uri: "foo1_aggregation"
      intrinsic_args {
        input_tensor {
          name: "foo"
          dtype: DT_INT32
          shape {}
        }
      }
      output_tensors {
        name: "foo_out"
        dtype: DT_INT32
        shape {}
      }
    }
  )pb"));
  EXPECT_CALL(callback_, OnAcceptClients(0, kNumClients, _));
  EXPECT_THAT(protocol->Start(kNumClients), IsOk());

  // Every parser created by the factory yields a scalar DT_INT32 "foo" tensor
  // holding 1.
  EXPECT_CALL(checkpoint_parser_factory_, Create(_)).WillRepeatedly(Invoke([&] {
    auto parser = std::make_unique<MockCheckpointParser>();
    EXPECT_CALL(*parser, GetTensor(StrEq("foo"))).WillOnce(Invoke([&] {
      return Tensor::Create(DT_INT32, {}, CreateTestData({1}));
    }));
    return parser;
  }));

  // Schedule receiving inputs on 10 concurrent threads.
  auto scheduler = CreateThreadPoolScheduler(10);
  for (int64_t i = 0; i < kNumClients; ++i) {
    scheduler->Schedule([&, i]() {
      EXPECT_THAT(protocol->ReceiveClientMessage(i, MakeClientMessage()),
                  IsOk());
    });
  }

  // Wait until the aggregation function has started blocking.
  aggregation_blocked_notification.WaitForNotification();

  // Busy-wait until GetStatus() reports no pending clients.
  StatusMessage status_message;
  do {
    status_message = protocol->GetStatus();
  } while (status_message.num_clients_pending() > 0);

  // At this point one input must be blocked inside the aggregation waiting for
  // the notification, 3 inputs should have already gone through the
  // aggregation, and the remaining 6 inputs should be blocked waiting to enter
  // the aggregation.

  // TODO(team): Need to revise the status implementation because it
  // treats received and pending (queued) inputs "as aggregated and pending".
  EXPECT_THAT(protocol->GetStatus(),
              EqualsProto<StatusMessage>(
                  PARSE_TEXT_PROTO("num_clients_completed: 10 "
                                   "num_inputs_aggregated_and_pending: 7 "
                                   "num_inputs_aggregated_and_included: 3")));

  // Release the aggregation call that is waiting on the notification.
  resume_aggregation_notification.Notify();

  // Abort and let all blocked aggregations continue.
  EXPECT_THAT(protocol->Abort(), IsOk());
  scheduler->WaitUntilIdle();

  // All 10 inputs should now be discarded.
  EXPECT_THAT(
      protocol->GetStatus(),
      EqualsProto<StatusMessage>(PARSE_TEXT_PROTO("num_clients_completed: 10 "
                                                  "num_inputs_discarded: 10")));
}
970 
971 }  // namespace
972 }  // namespace fcp::aggregation
973