1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/grpc_federated_protocol.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23
24 #include "google/protobuf/text_format.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/random/random.h"
28 #include "absl/status/status.h"
29 #include "absl/synchronization/blocking_counter.h"
30 #include "absl/time/time.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/client/cache/test_helpers.h"
33 #include "fcp/client/diag_codes.pb.h"
34 #include "fcp/client/engine/engine.pb.h"
35 #include "fcp/client/grpc_bidi_stream.h"
36 #include "fcp/client/http/http_client.h"
37 #include "fcp/client/http/testing/test_helpers.h"
38 #include "fcp/client/interruptible_runner.h"
39 #include "fcp/client/stats.h"
40 #include "fcp/client/test_helpers.h"
41 #include "fcp/protos/federated_api.pb.h"
42 #include "fcp/secagg/client/secagg_client.h"
43 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
44 #include "fcp/secagg/testing/fake_prng.h"
45 #include "fcp/secagg/testing/mock_send_to_server_interface.h"
46 #include "fcp/secagg/testing/mock_state_transition_listener.h"
47 #include "fcp/testing/testing.h"
48
49 namespace fcp::client {
50 namespace {
51
52 using ::fcp::EqualsProto;
53 using ::fcp::IsCode;
54 using ::fcp::client::http::FakeHttpResponse;
55 using ::fcp::client::http::HttpRequest;
56 using ::fcp::client::http::MockHttpClient;
57 using ::fcp::client::http::SimpleHttpRequestMatcher;
58 using ::google::internal::federatedml::v2::AcceptanceInfo;
59 using ::google::internal::federatedml::v2::CheckinRequest;
60 using ::google::internal::federatedml::v2::ClientStreamMessage;
61 using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
62 using ::google::internal::federatedml::v2::EligibilityEvalPayload;
63 using ::google::internal::federatedml::v2::HttpCompressionFormat;
64 using ::google::internal::federatedml::v2::ReportResponse;
65 using ::google::internal::federatedml::v2::RetryWindow;
66 using ::google::internal::federatedml::v2::ServerStreamMessage;
67 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
68 using ::google::internal::federatedml::v2::TaskWeight;
69 using ::testing::_;
70 using ::testing::AllOf;
71 using ::testing::DoAll;
72 using ::testing::DoubleEq;
73 using ::testing::DoubleNear;
74 using ::testing::Eq;
75 using ::testing::Field;
76 using ::testing::FieldsAre;
77 using ::testing::Ge;
78 using ::testing::Gt;
79 using ::testing::HasSubstr;
80 using ::testing::InSequence;
81 using ::testing::IsEmpty;
82 using ::testing::Lt;
83 using ::testing::MockFunction;
84 using ::testing::NiceMock;
85 using ::testing::Not;
86 using ::testing::Optional;
87 using ::testing::Pair;
88 using ::testing::Pointee;
89 using ::testing::Return;
90 using ::testing::SetArgPointee;
91 using ::testing::StrictMock;
92 using ::testing::UnorderedElementsAre;
93 using ::testing::VariantWith;
94
// Identity values handed to the GrpcFederatedProtocol constructor and expected
// to be echoed back in the checkin requests the client sends.
constexpr char kPopulationName[] = "TEST/POPULATION";
constexpr char kRetryToken[] = "OLD_RETRY_TOKEN";
constexpr char kClientVersion[] = "CLIENT_VERSION";
constexpr char kAttestationMeasurement[] = "ATTESTATION_MEASUREMENT";

// Payload values used to populate the fake server responses.
constexpr char kPlan[] = "CLIENT_ONLY_PLAN";
constexpr char kInitCheckpoint[] = "INIT_CHECKPOINT";
constexpr char kExecutionPhaseId[] = "TEST/POPULATION/TEST_TASK#1234.ab35";
constexpr char kFederatedSelectUriTemplate[] = "https://federated.select";

// Secure aggregation parameters placed in the fake accepted checkin response.
constexpr int kSecAggExpectedNumberOfClients = 10;
constexpr int kSecAggMinSurvivingClientsForReconstruction = 8;
constexpr int kSecAggMinClientsInServerVisibleAggregate = 4;
106
107 class MockGrpcBidiStream : public GrpcBidiStreamInterface {
108 public:
109 MOCK_METHOD(absl::Status, Send, (ClientStreamMessage*), (override));
110 MOCK_METHOD(absl::Status, Receive, (ServerStreamMessage*), (override));
111 MOCK_METHOD(void, Close, (), (override));
112 MOCK_METHOD(int64_t, ChunkingLayerBytesSent, (), (override));
113 MOCK_METHOD(int64_t, ChunkingLayerBytesReceived, (), (override));
114 };
115
// Transient error retry settings returned by the mock flags, together with the
// range of retry delays (in seconds) the tests accept for them.
constexpr int kTransientErrorsRetryPeriodSecs = 10;
constexpr double kTransientErrorsRetryDelayJitterPercent = 0.1;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMin = 9.0;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMax = 11.0;

// Permanent error retry settings returned by the mock flags, together with the
// range of retry delays (in seconds) the tests accept for them.
constexpr int kPermanentErrorsRetryPeriodSecs = 100;
constexpr double kPermanentErrorsRetryDelayJitterPercent = 0.2;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMin = 80.0;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMax = 120.0;
124
ExpectTransientErrorRetryWindow(const RetryWindow & retry_window)125 void ExpectTransientErrorRetryWindow(const RetryWindow& retry_window) {
126 // The calculated retry delay must lie within the expected transient errors
127 // retry delay range.
128 EXPECT_THAT(retry_window.delay_min().seconds() +
129 retry_window.delay_min().nanos() / 1000000000,
130 AllOf(Ge(kExpectedTransientErrorsRetryPeriodSecsMin),
131 Lt(kExpectedTransientErrorsRetryPeriodSecsMax)));
132 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
133 }
134
ExpectPermanentErrorRetryWindow(const RetryWindow & retry_window)135 void ExpectPermanentErrorRetryWindow(const RetryWindow& retry_window) {
136 // The calculated retry delay must lie within the expected permanent errors
137 // retry delay range.
138 EXPECT_THAT(retry_window.delay_min().seconds() +
139 retry_window.delay_min().nanos() / 1000000000,
140 AllOf(Ge(kExpectedPermanentErrorsRetryPeriodSecsMin),
141 Lt(kExpectedPermanentErrorsRetryPeriodSecsMax)));
142 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
143 }
144
GetAcceptedRetryWindow()145 google::internal::federatedml::v2::RetryWindow GetAcceptedRetryWindow() {
146 google::internal::federatedml::v2::RetryWindow retry_window;
147 // Must not overlap with kTransientErrorsRetryPeriodSecs or
148 // kPermanentErrorsRetryPeriodSecs.
149 retry_window.mutable_delay_min()->set_seconds(200L);
150 retry_window.mutable_delay_max()->set_seconds(299L);
151 *retry_window.mutable_retry_token() = "RETRY_TOKEN_ACCEPTED";
152 return retry_window;
153 }
154
GetRejectedRetryWindow()155 google::internal::federatedml::v2::RetryWindow GetRejectedRetryWindow() {
156 google::internal::federatedml::v2::RetryWindow retry_window;
157 // Must not overlap with kTransientErrorsRetryPeriodSecs or
158 // kPermanentErrorsRetryPeriodSecs.
159 retry_window.mutable_delay_min()->set_seconds(300);
160 retry_window.mutable_delay_max()->set_seconds(399L);
161 *retry_window.mutable_retry_token() = "RETRY_TOKEN_REJECTED";
162 return retry_window;
163 }
164
ExpectAcceptedRetryWindow(const RetryWindow & retry_window)165 void ExpectAcceptedRetryWindow(const RetryWindow& retry_window) {
166 // The calculated retry delay must lie within the expected permanent errors
167 // retry delay range.
168 EXPECT_THAT(retry_window.delay_min().seconds() +
169 retry_window.delay_min().nanos() / 1000000000,
170 AllOf(Ge(200), Lt(299L)));
171 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
172 }
173
ExpectRejectedRetryWindow(const RetryWindow & retry_window)174 void ExpectRejectedRetryWindow(const RetryWindow& retry_window) {
175 // The calculated retry delay must lie within the expected permanent errors
176 // retry delay range.
177 EXPECT_THAT(retry_window.delay_min().seconds() +
178 retry_window.delay_min().nanos() / 1000000000,
179 AllOf(Ge(300), Lt(399)));
180 EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
181 }
182
GetFakeCheckinRequestAck(const RetryWindow & accepted_retry_window=GetAcceptedRetryWindow (),const RetryWindow & rejected_retry_window=GetRejectedRetryWindow ())183 ServerStreamMessage GetFakeCheckinRequestAck(
184 const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
185 const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
186 ServerStreamMessage checkin_request_ack_message;
187 *checkin_request_ack_message.mutable_checkin_request_ack()
188 ->mutable_retry_window_if_accepted() = accepted_retry_window;
189 *checkin_request_ack_message.mutable_checkin_request_ack()
190 ->mutable_retry_window_if_rejected() = rejected_retry_window;
191 return checkin_request_ack_message;
192 }
193
GetFakeEnabledEligibilityCheckinResponse(const std::string & plan,const std::string & init_checkpoint,const std::string & execution_id)194 ServerStreamMessage GetFakeEnabledEligibilityCheckinResponse(
195 const std::string& plan, const std::string& init_checkpoint,
196 const std::string& execution_id) {
197 ServerStreamMessage checkin_response_message;
198 EligibilityEvalPayload* eval_payload =
199 checkin_response_message.mutable_eligibility_eval_checkin_response()
200 ->mutable_eligibility_eval_payload();
201 eval_payload->set_plan(plan);
202 eval_payload->set_init_checkpoint(init_checkpoint);
203 eval_payload->set_execution_id(execution_id);
204 return checkin_response_message;
205 }
206
GetFakeDisabledEligibilityCheckinResponse()207 ServerStreamMessage GetFakeDisabledEligibilityCheckinResponse() {
208 ServerStreamMessage checkin_response_message;
209 checkin_response_message.mutable_eligibility_eval_checkin_response()
210 ->mutable_no_eligibility_eval_configured();
211 return checkin_response_message;
212 }
213
GetFakeRejectedEligibilityCheckinResponse()214 ServerStreamMessage GetFakeRejectedEligibilityCheckinResponse() {
215 ServerStreamMessage rejection_response_message;
216 rejection_response_message.mutable_eligibility_eval_checkin_response()
217 ->mutable_rejection_info();
218 return rejection_response_message;
219 }
220
GetFakeTaskEligibilityInfo()221 TaskEligibilityInfo GetFakeTaskEligibilityInfo() {
222 TaskEligibilityInfo eligibility_info;
223 TaskWeight* task_weight = eligibility_info.mutable_task_weights()->Add();
224 task_weight->set_task_name("foo");
225 task_weight->set_weight(567.8);
226 return eligibility_info;
227 }
228
GetFakeRejectedCheckinResponse()229 ServerStreamMessage GetFakeRejectedCheckinResponse() {
230 ServerStreamMessage rejection_response_message;
231 rejection_response_message.mutable_checkin_response()
232 ->mutable_rejection_info();
233 return rejection_response_message;
234 }
235
GetFakeAcceptedCheckinResponse(const std::string & plan,const std::string & init_checkpoint,const std::string & federated_select_uri_template,const std::string & phase_id,bool use_secure_aggregation)236 ServerStreamMessage GetFakeAcceptedCheckinResponse(
237 const std::string& plan, const std::string& init_checkpoint,
238 const std::string& federated_select_uri_template,
239 const std::string& phase_id, bool use_secure_aggregation) {
240 ServerStreamMessage checkin_response_message;
241 AcceptanceInfo* acceptance_info =
242 checkin_response_message.mutable_checkin_response()
243 ->mutable_acceptance_info();
244 acceptance_info->set_plan(plan);
245 acceptance_info->set_execution_phase_id(phase_id);
246 acceptance_info->set_init_checkpoint(init_checkpoint);
247 acceptance_info->mutable_federated_select_uri_info()->set_uri_template(
248 federated_select_uri_template);
249 if (use_secure_aggregation) {
250 auto sec_agg =
251 acceptance_info->mutable_side_channel_protocol_execution_info()
252 ->mutable_secure_aggregation();
253 sec_agg->set_expected_number_of_clients(kSecAggExpectedNumberOfClients);
254 sec_agg->set_minimum_surviving_clients_for_reconstruction(
255 kSecAggMinSurvivingClientsForReconstruction);
256 sec_agg->set_minimum_clients_in_server_visible_aggregate(
257 kSecAggMinClientsInServerVisibleAggregate);
258 checkin_response_message.mutable_checkin_response()
259 ->mutable_protocol_options_response()
260 ->mutable_side_channels()
261 ->mutable_secure_aggregation()
262 ->set_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
263 }
264 return checkin_response_message;
265 }
266
GetFakeReportResponse()267 ServerStreamMessage GetFakeReportResponse() {
268 ServerStreamMessage report_response_message;
269 *report_response_message.mutable_report_response() = ReportResponse();
270 return report_response_message;
271 }
272
GetExpectedEligibilityEvalCheckinRequest(bool enable_http_resource_support=false)273 ClientStreamMessage GetExpectedEligibilityEvalCheckinRequest(
274 bool enable_http_resource_support = false) {
275 ClientStreamMessage expected_message;
276 EligibilityEvalCheckinRequest* checkin_request =
277 expected_message.mutable_eligibility_eval_checkin_request();
278 checkin_request->set_population_name(kPopulationName);
279 checkin_request->set_client_version(kClientVersion);
280 checkin_request->set_retry_token(kRetryToken);
281 checkin_request->set_attestation_measurement(kAttestationMeasurement);
282 checkin_request->mutable_protocol_options_request()
283 ->mutable_side_channels()
284 ->mutable_secure_aggregation()
285 ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
286 checkin_request->mutable_protocol_options_request()->set_should_ack_checkin(
287 true);
288 checkin_request->mutable_protocol_options_request()
289 ->add_supported_http_compression_formats(
290 HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
291
292 if (enable_http_resource_support) {
293 checkin_request->mutable_protocol_options_request()
294 ->set_supports_http_download(true);
295 checkin_request->mutable_protocol_options_request()
296 ->set_supports_eligibility_eval_http_download(true);
297 }
298
299 return expected_message;
300 }
301
302 // This returns the CheckinRequest gRPC proto we expect each Checkin(...) call
303 // to result in.
GetExpectedCheckinRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info=std::nullopt,bool enable_http_resource_support=false)304 ClientStreamMessage GetExpectedCheckinRequest(
305 const std::optional<TaskEligibilityInfo>& task_eligibility_info =
306 std::nullopt,
307 bool enable_http_resource_support = false) {
308 ClientStreamMessage expected_message;
309 CheckinRequest* checkin_request = expected_message.mutable_checkin_request();
310 checkin_request->set_population_name(kPopulationName);
311 checkin_request->set_client_version(kClientVersion);
312 checkin_request->set_retry_token(kRetryToken);
313 checkin_request->set_attestation_measurement(kAttestationMeasurement);
314 checkin_request->mutable_protocol_options_request()
315 ->mutable_side_channels()
316 ->mutable_secure_aggregation()
317 ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
318 checkin_request->mutable_protocol_options_request()->set_should_ack_checkin(
319 false);
320 checkin_request->mutable_protocol_options_request()
321 ->add_supported_http_compression_formats(
322 HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
323
324 if (enable_http_resource_support) {
325 checkin_request->mutable_protocol_options_request()
326 ->set_supports_http_download(true);
327 checkin_request->mutable_protocol_options_request()
328 ->set_supports_eligibility_eval_http_download(true);
329 }
330
331 if (task_eligibility_info.has_value()) {
332 *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
333 }
334 return expected_message;
335 }
336
337 class GrpcFederatedProtocolTest
338 // The first parameter indicates whether support for HTTP task resources
339 // should be enabled.
340 : public testing::TestWithParam<bool> {
341 public:
GrpcFederatedProtocolTest()342 GrpcFederatedProtocolTest() {
343 // The gRPC stream should always be closed at the end of all tests.
344 EXPECT_CALL(*mock_grpc_bidi_stream_, Close());
345 }
346
347 protected:
SetUp()348 void SetUp() override {
349 enable_http_resource_support_ = GetParam();
350 EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
351 .WillRepeatedly(Return(0));
352 EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
353 .WillRepeatedly(Return(0));
354 EXPECT_CALL(mock_flags_,
355 federated_training_transient_errors_retry_delay_secs)
356 .WillRepeatedly(Return(kTransientErrorsRetryPeriodSecs));
357 EXPECT_CALL(mock_flags_,
358 federated_training_transient_errors_retry_delay_jitter_percent)
359 .WillRepeatedly(Return(kTransientErrorsRetryDelayJitterPercent));
360 EXPECT_CALL(mock_flags_,
361 federated_training_permanent_errors_retry_delay_secs)
362 .WillRepeatedly(Return(kPermanentErrorsRetryPeriodSecs));
363 EXPECT_CALL(mock_flags_,
364 federated_training_permanent_errors_retry_delay_jitter_percent)
365 .WillRepeatedly(Return(kPermanentErrorsRetryDelayJitterPercent));
366 EXPECT_CALL(mock_flags_, federated_training_permanent_error_codes)
367 .WillRepeatedly(Return(std::vector<int32_t>{
368 static_cast<int32_t>(absl::StatusCode::kNotFound),
369 static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
370 static_cast<int32_t>(absl::StatusCode::kUnimplemented)}));
371 EXPECT_CALL(mock_flags_,
372 enable_grpc_with_eligibility_eval_http_resource_support)
373 .WillRepeatedly(Return(enable_http_resource_support_));
374
375 // We only initialize federated_protocol_ in this SetUp method, rather than
376 // in the test's constructor, to ensure that we can set mock flag values
377 // before the GrpcFederatedProtocol constructor is called. Using
378 // std::unique_ptr conveniently allows us to assign the field a new value
379 // after construction (which we could not do if the field's type was
380 // GrpcFederatedProtocol, since it doesn't have copy or move constructors).
381 federated_protocol_ = std::make_unique<GrpcFederatedProtocol>(
382 &mock_event_publisher_, &mock_log_manager_,
383 absl::WrapUnique(mock_secagg_runner_factory_), &mock_flags_,
384 /*http_client=*/
385 enable_http_resource_support_ ? &mock_http_client_ : nullptr,
386 // We want to inject mocks stored in unique_ptrs to the
387 // class-under-test, hence we transfer ownership via WrapUnique. To
388 // write expectations for the mock, we retain the raw pointer to it,
389 // which will be valid until GrpcFederatedProtocol's d'tor is called.
390 absl::WrapUnique(mock_grpc_bidi_stream_), kPopulationName, kRetryToken,
391 kClientVersion, kAttestationMeasurement,
392 mock_should_abort_.AsStdFunction(), absl::BitGen(),
393 InterruptibleRunner::TimingConfig{
394 .polling_period = absl::ZeroDuration(),
395 .graceful_shutdown_period = absl::InfiniteDuration(),
396 .extended_shutdown_period = absl::InfiniteDuration()},
397 &mock_resource_cache_);
398 }
399
TearDown()400 void TearDown() override {
401 fcp::client::http::HttpRequestHandle::SentReceivedBytes
402 sent_received_bytes = mock_http_client_.TotalSentReceivedBytes();
403
404 NetworkStats network_stats = federated_protocol_->GetNetworkStats();
405 EXPECT_THAT(network_stats.bytes_downloaded,
406 Ge(mock_grpc_bidi_stream_->ChunkingLayerBytesReceived() +
407 sent_received_bytes.received_bytes));
408 EXPECT_THAT(network_stats.bytes_uploaded,
409 Ge(mock_grpc_bidi_stream_->ChunkingLayerBytesSent() +
410 sent_received_bytes.sent_bytes));
411 // If any network traffic occurred, we expect to see some time reflected in
412 // the duration (if the flag is on).
413 if (network_stats.bytes_uploaded > 0) {
414 EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
415 }
416 }
417
418 // This function runs a successful
419 // EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction()) that
420 // results in an eligibility eval payload being returned by the server. This
421 // is a utility function used by Checkin*() tests that depend on a prior,
422 // successful execution of
423 // EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction()). It
424 // returns a absl::Status, which the caller should verify is OK using
425 // ASSERT_OK.
RunSuccessfulEligibilityEvalCheckin(bool eligibility_eval_enabled=true,const RetryWindow & accepted_retry_window=GetAcceptedRetryWindow (),const RetryWindow & rejected_retry_window=GetRejectedRetryWindow ())426 absl::Status RunSuccessfulEligibilityEvalCheckin(
427 bool eligibility_eval_enabled = true,
428 const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
429 const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
430 EXPECT_CALL(
431 *mock_grpc_bidi_stream_,
432 Send(Pointee(EqualsProto(GetExpectedEligibilityEvalCheckinRequest(
433 enable_http_resource_support_)))))
434 .WillOnce(Return(absl::OkStatus()));
435
436 const std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
437 EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
438 .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck(
439 accepted_retry_window, rejected_retry_window)),
440 Return(absl::OkStatus())))
441 .WillOnce(
442 DoAll(SetArgPointee<0>(
443 eligibility_eval_enabled
444 ? GetFakeEnabledEligibilityCheckinResponse(
445 kPlan, kInitCheckpoint, expected_execution_id)
446 : GetFakeDisabledEligibilityCheckinResponse()),
447 Return(absl::OkStatus())));
448
449 return federated_protocol_
450 ->EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction())
451 .status();
452 }
453
454 // This function runs a successful Checkin() that results in acceptance by the
455 // server. This is a utility function used by Report*() tests that depend on a
456 // prior, successful execution of Checkin().
457 // It returns a absl::Status, which the caller should verify is OK using
458 // ASSERT_OK.
RunSuccessfulCheckin(bool use_secure_aggregation,const std::optional<TaskEligibilityInfo> & task_eligibility_info=GetFakeTaskEligibilityInfo ())459 absl::StatusOr<FederatedProtocol::CheckinResult> RunSuccessfulCheckin(
460 bool use_secure_aggregation,
461 const std::optional<TaskEligibilityInfo>& task_eligibility_info =
462 GetFakeTaskEligibilityInfo()) {
463 EXPECT_CALL(*mock_grpc_bidi_stream_,
464 Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
465 task_eligibility_info, enable_http_resource_support_)))))
466 .WillOnce(Return(absl::OkStatus()));
467
468 {
469 InSequence seq;
470 EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
471 .WillOnce(
472 DoAll(SetArgPointee<0>(GetFakeAcceptedCheckinResponse(
473 kPlan, kInitCheckpoint, kFederatedSelectUriTemplate,
474 kExecutionPhaseId, use_secure_aggregation)),
475 Return(absl::OkStatus())))
476 .RetiresOnSaturation();
477 }
478
479 return federated_protocol_->Checkin(
480 task_eligibility_info, mock_task_received_callback_.AsStdFunction());
481 }
482
483 // See note in the constructor for why these are pointers.
484 StrictMock<MockGrpcBidiStream>* mock_grpc_bidi_stream_ =
485 new StrictMock<MockGrpcBidiStream>();
486
487 StrictMock<MockEventPublisher> mock_event_publisher_;
488 NiceMock<MockLogManager> mock_log_manager_;
489 StrictMock<MockSecAggRunnerFactory>* mock_secagg_runner_factory_ =
490 new StrictMock<MockSecAggRunnerFactory>();
491 StrictMock<MockSecAggRunner>* mock_secagg_runner_;
492 NiceMock<MockFlags> mock_flags_;
493 StrictMock<MockHttpClient> mock_http_client_;
494 NiceMock<MockFunction<bool()>> mock_should_abort_;
495 StrictMock<cache::MockResourceCache> mock_resource_cache_;
496 NiceMock<MockFunction<void(
497 const ::fcp::client::FederatedProtocol::EligibilityEvalTask&)>>
498 mock_eet_received_callback_;
499 NiceMock<MockFunction<void(
500 const ::fcp::client::FederatedProtocol::TaskAssignment&)>>
501 mock_task_received_callback_;
502
503 // The class under test.
504 std::unique_ptr<GrpcFederatedProtocol> federated_protocol_;
505 bool enable_http_resource_support_;
506 };
507
GenerateTestName(const testing::TestParamInfo<GrpcFederatedProtocolTest::ParamType> & info)508 std::string GenerateTestName(
509 const testing::TestParamInfo<GrpcFederatedProtocolTest::ParamType>& info) {
510 std::string name = info.param ? "Http_resource_support_enabled"
511 : "Http_resource_support_disabled";
512 return name;
513 }
514
515 INSTANTIATE_TEST_SUITE_P(NewVsOldBehavior, GrpcFederatedProtocolTest,
516 testing::Bool(), GenerateTestName);
517
518 using GrpcFederatedProtocolDeathTest = GrpcFederatedProtocolTest;
519 INSTANTIATE_TEST_SUITE_P(NewVsOldBehavior, GrpcFederatedProtocolDeathTest,
520 testing::Bool(), GenerateTestName);
521
// Verifies that two separately constructed GrpcFederatedProtocol instances
// produce different (randomly generated) transient error retry windows.
TEST_P(GrpcFederatedProtocolTest,
       TestTransientErrorRetryWindowDifferentAcrossDifferentInstances) {
  // Hold a copy rather than a reference: the protocol instance is destroyed
  // right below, and the value is still needed at the end of the test. A copy
  // is valid regardless of whether GetLatestRetryWindow() returns by value or
  // by reference.
  const RetryWindow retry_window1 = federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(retry_window1);
  federated_protocol_.reset(nullptr);

  // The previous mocks were owned (and have been destroyed) by the previous
  // protocol instance, so create fresh ones and re-establish the expectations
  // the fixture normally sets up.
  mock_grpc_bidi_stream_ = new StrictMock<MockGrpcBidiStream>();
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
      .WillRepeatedly(Return(0));
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
      .WillRepeatedly(Return(0));
  EXPECT_CALL(*mock_grpc_bidi_stream_, Close());
  mock_secagg_runner_factory_ = new StrictMock<MockSecAggRunnerFactory>();
  // Create a new GrpcFederatedProtocol instance. It should not produce the same
  // retry window value as the one we just got. This is a simple correctness
  // check to ensure that the value is at least randomly generated (and that we
  // don't accidentally use the random number generator incorrectly).
  federated_protocol_ = std::make_unique<GrpcFederatedProtocol>(
      &mock_event_publisher_, &mock_log_manager_,
      absl::WrapUnique(mock_secagg_runner_factory_), &mock_flags_,
      /*http_client=*/nullptr, absl::WrapUnique(mock_grpc_bidi_stream_),
      kPopulationName, kRetryToken, kClientVersion, kAttestationMeasurement,
      mock_should_abort_.AsStdFunction(), absl::BitGen(),
      InterruptibleRunner::TimingConfig{
          .polling_period = absl::ZeroDuration(),
          .graceful_shutdown_period = absl::InfiniteDuration(),
          .extended_shutdown_period = absl::InfiniteDuration()},
      &mock_resource_cache_);

  const RetryWindow retry_window2 = federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(retry_window2);

  EXPECT_THAT(retry_window1, Not(EqualsProto(retry_window2)));
}
558
// A Send() failure with a transient error code is returned to the caller and
// results in a transient error retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinSendFailsTransientError) {
  // The very first message EligibilityEvalCheckin(...) tries to send fails
  // with UNAVAILABLE; that error should be propagated as the result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::UnavailableError("foo")));

  const auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
  EXPECT_THAT(checkin_result.status().message(), "foo");
  // The server never sent any RetryWindows, so the retry window must have been
  // generated from the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
576
// A Send() failure with a permanent error code is returned to the caller and
// results in a permanent error retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinSendFailsPermanentError) {
  // The very first message EligibilityEvalCheckin(...) tries to send fails
  // with NOT_FOUND; that error should be propagated as the result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::NotFoundError("foo")));

  const auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
  EXPECT_THAT(checkin_result.status().message(), "foo");
  // The server never sent any RetryWindows, and NOT_FOUND is listed as a
  // permanent error in the flags, so the retry window must have been generated
  // from the *permanent* errors retry delay flag.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
595
596 // Tests the case where the blocking Send() call in EligibilityEvalCheckin is
597 // interrupted.
TEST_P(GrpcFederatedProtocolTest,TestEligibilityEvalCheckinSendInterrupted)598 TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinSendInterrupted) {
599 absl::BlockingCounter counter_should_abort(1);
600
601 // Make Send() block until the counter is decremented.
602 EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
603 .WillOnce([&counter_should_abort](ClientStreamMessage* ignored) {
604 counter_should_abort.Wait();
605 return absl::OkStatus();
606 });
607 // Make should_abort return false for the first two calls, and then make it
608 // decrement the counter and return true, triggering an abort sequence and
609 // unblocking the Send() call we caused to block above.
610 EXPECT_CALL(mock_should_abort_, Call())
611 .WillOnce(Return(false))
612 .WillOnce(Return(false))
613 .WillRepeatedly([&counter_should_abort] {
614 counter_should_abort.DecrementCount();
615 return true;
616 });
617 // In addition to the Close() call we expect in the test fixture above, expect
618 // an additional one (the one that induced the abort).
619 EXPECT_CALL(*mock_grpc_bidi_stream_, Close()).Times(1).RetiresOnSaturation();
620 EXPECT_CALL(mock_log_manager_,
621 LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC));
622
623 auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
624 mock_eet_received_callback_.AsStdFunction());
625
626 EXPECT_THAT(eligibility_checkin_result.status(), IsCode(CANCELLED));
627 // No RetryWindows were received from the server, so we expect to get a
628 // RetryWindow generated based on the transient errors retry delay flag.
629 ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
630 }
631
632 // If a CheckinRequestAck is requested in the ProtocolOptionsRequest but not
633 // received, UNIMPLEMENTED should be returned.
TEST_P(GrpcFederatedProtocolTest,TestEligibilityEvalCheckinMissingCheckinRequestAck)634 TEST_P(GrpcFederatedProtocolTest,
635 TestEligibilityEvalCheckinMissingCheckinRequestAck) {
636 // We immediately return an EligibilityEvalCheckinResponse, rather than
637 // returning a CheckinRequestAck first.
638 EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
639 .WillOnce(Return(absl::OkStatus()));
640 EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
641 .WillOnce(
642 DoAll(SetArgPointee<0>(GetFakeRejectedEligibilityCheckinResponse()),
643 Return(absl::OkStatus())));
644 EXPECT_CALL(
645 mock_log_manager_,
646 LogDiag(
647 ProdDiagCode::
648 BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD)); // NOLINT
649
650 auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
651 mock_eet_received_callback_.AsStdFunction());
652
653 EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNIMPLEMENTED));
654 // No RetryWindows were received from the server, so we expect to get a
655 // RetryWindow generated based on the *permanent* errors retry delay flag,
656 // since UNIMPLEMENTED is marked as a permanent error in the flags.
657 ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
658 }
659
// A failure while waiting for the CheckinRequestAck is returned to the caller
// and results in a transient error retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinWaitForCheckinRequestAckFails) {
  const std::string error_message = "foo";

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  // The very first Receive() call (i.e. the one expecting the
  // CheckinRequestAck) fails.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(Return(absl::AbortedError(error_message)));

  const auto checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(checkin_result.status(), IsCode(ABORTED));
  EXPECT_THAT(checkin_result.status().message(), error_message);
  // The server never sent any RetryWindows, so the retry window must have been
  // generated from the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
680
// Verifies that an error from the second Receive() call (the one waiting for
// the EligibilityEvalCheckinResponse) is returned to the caller, and that the
// rejected retry window is used because an ack had already been received.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinWaitForCheckinResponseFails) {
  const std::string failure_text = "foo";

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      // First Receive(): the CheckinRequestAck arrives normally.
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      // Second Receive(): fails instead of delivering the
      // EligibilityEvalCheckinResponse.
      .WillOnce(Return(absl::AbortedError(failure_text)));

  const auto checkin_outcome = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  const absl::Status& outcome_status = checkin_outcome.status();
  EXPECT_THAT(outcome_status, IsCode(ABORTED));
  EXPECT_THAT(outcome_status.message(), failure_text);
  // A checkin that fails after the ack was received should report the rejected
  // retry window.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
703
// Verifies that a rejected eligibility eval checkin produces a Rejection
// result, leaves the 'EET received' callback untouched, and reports the
// rejected retry window.
TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinRejection) {
  const auto rejection_reply = GetFakeRejectedEligibilityCheckinResponse();

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  // The ack arrives first, followed by the rejection itself.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(DoAll(SetArgPointee<0>(rejection_reply),
                      Return(absl::OkStatus())));

  // No EET is handed to the client, hence the 'eet received' callback must
  // never fire.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  const auto checkin_outcome = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome);
  EXPECT_THAT(*checkin_outcome, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
727
// Verifies that a "disabled" eligibility eval checkin response produces an
// EligibilityEvalDisabled result, leaves the 'EET received' callback
// untouched, and reports the rejected retry window.
TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinDisabled) {
  const auto disabled_reply = GetFakeDisabledEligibilityCheckinResponse();

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  // The ack arrives first, followed by the "disabled" response.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(DoAll(SetArgPointee<0>(disabled_reply),
                      Return(absl::OkStatus())));

  // No EET is handed to the client, hence the 'eet received' callback must
  // never fire.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  const auto checkin_outcome = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome);
  EXPECT_THAT(*checkin_outcome,
              VariantWith<FederatedProtocol::EligibilityEvalDisabled>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
751
// Verifies the successful eligibility eval checkin path with inline task
// resources: the outgoing request, the 'EET received' callback, the returned
// task and the resulting retry window are all checked.
TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinEnabled) {
  const std::string plan_data = kPlan;
  const std::string checkpoint_data = kInitCheckpoint;
  const std::string execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";

  // Unlike the earlier tests, which accept any request via '_', this one pins
  // down the exact eligibility eval checkin request that gets sent.
  EXPECT_CALL(*mock_grpc_bidi_stream_,
              Send(Pointee(EqualsProto(GetExpectedEligibilityEvalCheckinRequest(
                  enable_http_resource_support_)))))
      .WillOnce(Return(absl::OkStatus()));

  // The server acks the request and then delivers an eligibility eval task
  // whose plan and checkpoint are embedded in the response.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(
          DoAll(SetArgPointee<0>(GetFakeEnabledEligibilityCheckinResponse(
                    plan_data, checkpoint_data, execution_id)),
                Return(absl::OkStatus())));
  EXPECT_CALL(
      mock_log_manager_,
      LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));

  // Inline availability of the resource data must not suppress the
  // 'EET received' callback.
  EXPECT_CALL(
      mock_eet_received_callback_,
      Call(FieldsAre(FieldsAre("", ""), execution_id, Eq(std::nullopt))));

  const auto checkin_outcome = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome);
  if (!enable_http_resource_support_) {
    // Without HTTP support the payload is surfaced as plain strings.
    EXPECT_THAT(*checkin_outcome,
                VariantWith<FederatedProtocol::EligibilityEvalTask>(
                    FieldsAre(FieldsAre(plan_data, checkpoint_data),
                              execution_id, Eq(std::nullopt))));
  } else {
    // With HTTP support enabled the payload is always surfaced as absl::Cord
    // values, whether or not anything was downloaded over HTTP.
    EXPECT_THAT(*checkin_outcome,
                VariantWith<FederatedProtocol::EligibilityEvalTask>(
                    FieldsAre(FieldsAre(absl::Cord(plan_data),
                                        absl::Cord(checkpoint_data)),
                              execution_id, Eq(std::nullopt))));
  }
  // Completing an eligibility eval checkin does not mean the client has been
  // accepted for a specific task yet, so the rejected RetryWindow applies.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
805
// Verifies the successful eligibility eval checkin path when both the plan and
// the checkpoint are referenced by URI and fetched over HTTP.
TEST_P(GrpcFederatedProtocolTest,
       TestEligiblityEvalCheckinEnabledWithHttpResourcesDownloaded) {
  if (!enable_http_resource_support_) {
    GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
                    "is enabled";
    return;
  }

  const std::string plan_data = kPlan;
  const std::string checkpoint_data = kInitCheckpoint;
  const std::string plan_url = "https://fake.uri/plan";
  const std::string checkpoint_url = "https://fake.uri/checkpoint";
  const std::string execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";

  // Build a response with empty inline data and URIs for both resources.
  ServerStreamMessage task_reply = GetFakeEnabledEligibilityCheckinResponse(
      /*plan=*/"", /*init_checkpoint=*/"", execution_id);
  task_reply.mutable_eligibility_eval_checkin_response()
      ->mutable_eligibility_eval_payload()
      ->mutable_plan_resource()
      ->set_uri(plan_url);
  task_reply.mutable_eligibility_eval_checkin_response()
      ->mutable_eligibility_eval_payload()
      ->mutable_init_checkpoint_resource()
      ->set_uri(checkpoint_url);

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(DoAll(SetArgPointee<0>(task_reply), Return(absl::OkStatus())));
  EXPECT_CALL(
      mock_log_manager_,
      LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));

  {
    InSequence ordered;
    // The 'EET received' notification has to precede the resource downloads.
    EXPECT_CALL(
        mock_eet_received_callback_,
        Call(FieldsAre(FieldsAre("", ""), execution_id, Eq(std::nullopt))));

    EXPECT_CALL(mock_http_client_,
                PerformSingleRequest(SimpleHttpRequestMatcher(
                    plan_url, HttpRequest::Method::kGet, _, "")))
        .WillOnce(Return(FakeHttpResponse(200, {}, plan_data)));

    EXPECT_CALL(mock_http_client_,
                PerformSingleRequest(SimpleHttpRequestMatcher(
                    checkpoint_url, HttpRequest::Method::kGet, _, "")))
        .WillOnce(Return(FakeHttpResponse(200, {}, checkpoint_data)));
  }

  {
    InSequence ordered;
    // HTTP usage is logged first, then the successful fetch.
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::
                HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED));
  }

  // Run the eligibility eval checkin under test.
  const auto checkin_outcome = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome);
  // The downloaded bytes must show up in the returned task.
  EXPECT_THAT(
      *checkin_outcome,
      VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
          FieldsAre(absl::Cord(plan_data), absl::Cord(checkpoint_data)),
          execution_id, Eq(std::nullopt))));
}
883
// Verifies that a failed HTTP download of the eligibility eval plan (HTTP 404)
// makes EligibilityEvalCheckin() return NOT_FOUND and select the permanent
// error retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligiblityEvalCheckinEnabledWithHttpResourcesPlanDataFetchFailed) {
  if (!enable_http_resource_support_) {
    GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
                    "is enabled";
    return;
  }

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));

  // Only the plan is referenced by URI; the checkpoint stays inline. No plan
  // payload is needed because the plan download is made to fail below.
  std::string plan_uri = "https://fake.uri/plan";
  std::string expected_checkpoint = kInitCheckpoint;
  std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
  ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
      /*plan=*/"", expected_checkpoint, expected_execution_id);
  fake_response.mutable_eligibility_eval_checkin_response()
      ->mutable_eligibility_eval_payload()
      ->mutable_plan_resource()
      ->set_uri(plan_uri);
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(
          DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
  EXPECT_CALL(
      mock_log_manager_,
      LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));

  // The plan download fails with HTTP 404.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(404, {}, "")));

  {
    InSequence seq;
    // HTTP usage is logged first, then the failed fetch.
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::
                HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
  }

  // Issue the eligibility eval checkin.
  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
  EXPECT_THAT(eligibility_checkin_result.status().message(),
              HasSubstr("plan fetch failed"));
  EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("404"));
  // The EligibilityEvalCheckin call is expected to return the permanent error
  // retry window, since 404 maps to a permanent error.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
944
// Verifies that a failed HTTP download of the eligibility eval checkpoint
// (HTTP 503) makes EligibilityEvalCheckin() return UNAVAILABLE and select the
// rejected retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligiblityEvalCheckinEnabledWithHttpResourcesCheckpointFetchFailed) {
  if (!enable_http_resource_support_) {
    GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
                    "is enabled";
    return;
  }

  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));

  // Only the checkpoint is referenced by URI; the plan stays inline. No
  // checkpoint payload is needed because the download is made to fail below.
  std::string expected_plan = kPlan;
  std::string checkpoint_uri = "https://fake.uri/checkpoint";
  std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
  ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
      expected_plan, /*init_checkpoint=*/"", expected_execution_id);
  fake_response.mutable_eligibility_eval_checkin_response()
      ->mutable_eligibility_eval_payload()
      ->mutable_init_checkpoint_resource()
      ->set_uri(checkpoint_uri);
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(
          DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
  EXPECT_CALL(
      mock_log_manager_,
      LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));

  // The checkpoint download fails with HTTP 503.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  checkpoint_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(503, {}, "")));

  {
    InSequence seq;
    // HTTP usage is logged first, then the failed fetch.
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::
                HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
  }

  // Issue the eligibility eval checkin.
  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
  EXPECT_THAT(eligibility_checkin_result.status().message(),
              HasSubstr("checkpoint fetch failed"));
  EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
  // The EligibilityEvalCheckin call is expected to return the rejected error
  // retry window, since 503 maps to a transient error.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1005
1006 // Tests that the protocol correctly sanitizes any invalid values it may have
1007 // received from the server.
TEST_P(GrpcFederatedProtocolTest,TestNegativeMinMaxRetryDelayValueSanitization)1008 TEST_P(GrpcFederatedProtocolTest,
1009 TestNegativeMinMaxRetryDelayValueSanitization) {
1010 google::internal::federatedml::v2::RetryWindow retry_window;
1011 retry_window.mutable_delay_min()->set_seconds(-1);
1012 retry_window.mutable_delay_max()->set_seconds(-2);
1013
1014 // The above retry window's negative min/max values should be clamped to 0.
1015 google::internal::federatedml::v2::RetryWindow expected_retry_window;
1016 expected_retry_window.mutable_delay_min()->set_seconds(0);
1017 expected_retry_window.mutable_delay_max()->set_seconds(0);
1018
1019 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
1020 /* eligibility_eval_enabled=*/true, retry_window, retry_window));
1021 const RetryWindow& actual_retry_window =
1022 federated_protocol_->GetLatestRetryWindow();
1023 // The above retry window's invalid max value should be clamped to the min
1024 // value (minus some errors introduced by the inaccuracy of double
1025 // multiplication).
1026 EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1027 actual_retry_window.delay_min().nanos() / 1000000000.0,
1028 DoubleEq(0));
1029 EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1030 actual_retry_window.delay_max().nanos() / 1000000000.0,
1031 DoubleEq(0));
1032 }
1033
1034 // Tests that the protocol correctly sanitizes any invalid values it may have
1035 // received from the server.
TEST_P(GrpcFederatedProtocolTest,TestInvalidMaxRetryDelayValueSanitization)1036 TEST_P(GrpcFederatedProtocolTest, TestInvalidMaxRetryDelayValueSanitization) {
1037 google::internal::federatedml::v2::RetryWindow retry_window;
1038 retry_window.mutable_delay_min()->set_seconds(1234);
1039 retry_window.mutable_delay_max()->set_seconds(1233); // less than delay_min
1040
1041 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
1042 /* eligibility_eval_enabled=*/true, retry_window, retry_window));
1043 const RetryWindow& actual_retry_window =
1044 federated_protocol_->GetLatestRetryWindow();
1045 // The above retry window's invalid max value should be clamped to the min
1046 // value (minus some errors introduced by the inaccuracy of double
1047 // multiplication). Note that DoubleEq enforces too precise of bounds, so we
1048 // use DoubleNear instead.
1049 EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1050 actual_retry_window.delay_min().nanos() / 1000000000.0,
1051 DoubleNear(1234.0, 0.02));
1052 EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1053 actual_retry_window.delay_max().nanos() / 1000000000.0,
1054 DoubleNear(1234.0, 0.02));
1055 }
1056
// Verifies that calling Checkin() without a TaskEligibilityInfo after a
// successful eligibility eval checkin terminates the process.
TEST_P(GrpcFederatedProtocolDeathTest, TestCheckinMissingTaskEligibilityInfo) {
  // Start from a completed eligibility eval checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The protocol requires a TaskEligibilityInfo derived from the plan included
  // in the eligibility eval checkin response payload, so omitting it from the
  // Checkin(...) request must be fatal.
  ASSERT_DEATH(
      {
        auto ignored_result = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1071
// Verifies that a transient (UNAVAILABLE) failure of the first Send() issued
// by Checkin() is returned to the caller and selects the rejected retry
// window.
TEST_P(GrpcFederatedProtocolTest, TestCheckinSendFailsTransientError) {
  const std::string failure_text = "foo";

  // Start from a completed eligibility eval checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The first message Checkin(...) tries to send fails with UNAVAILABLE; that
  // error is expected to become the call's result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::UnavailableError(failure_text)));

  const auto checkin_outcome = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  const absl::Status& outcome_status = checkin_outcome.status();
  EXPECT_THAT(outcome_status, IsCode(UNAVAILABLE));
  EXPECT_THAT(outcome_status.message(), failure_text);
  // The eligibility eval checkin already delivered RetryWindows from the
  // server, so the 'rejected' retry window applies.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1091
// Verifies that a permanent (NOT_FOUND) failure of the first Send() issued by
// Checkin() is returned to the caller and selects the permanent error retry
// window.
TEST_P(GrpcFederatedProtocolTest, TestCheckinSendFailsPermanentError) {
  const std::string failure_text = "foo";

  // Start from a completed eligibility eval checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The first message Checkin(...) tries to send fails with NOT_FOUND; that
  // error is expected to become the call's result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::NotFoundError(failure_text)));

  const auto checkin_outcome = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  const absl::Status& outcome_status = checkin_outcome.status();
  EXPECT_THAT(outcome_status, IsCode(NOT_FOUND));
  EXPECT_THAT(outcome_status.message(), failure_text);
  // NOT_FOUND is flagged as a *permanent* error, and permanent errors always
  // yield a window derived from the permanent errors retry delay flag. This
  // holds even though RetryWindows were already received during the
  // eligibility eval checkin (i.e. regardless of whether a CheckinRequestAck
  // was seen).
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1115
1116 // Tests the case where the blocking Send() call in Checkin is interrupted.
TEST_P(GrpcFederatedProtocolTest,TestCheckinSendInterrupted)1117 TEST_P(GrpcFederatedProtocolTest, TestCheckinSendInterrupted) {
1118 // Issue an eligibility eval checkin first.
1119 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1120
1121 absl::BlockingCounter counter_should_abort(1);
1122
1123 // Make Send() block until the counter is decremented.
1124 EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
1125 .WillOnce([&counter_should_abort](ClientStreamMessage* ignored) {
1126 counter_should_abort.Wait();
1127 return absl::OkStatus();
1128 });
1129 // Make should_abort return false for the first two calls, and then make it
1130 // decrement the counter and return true, triggering an abort sequence and
1131 // unblocking the Send() call we caused to block above.
1132 EXPECT_CALL(mock_should_abort_, Call())
1133 .WillOnce(Return(false))
1134 .WillOnce(Return(false))
1135 .WillRepeatedly([&counter_should_abort] {
1136 counter_should_abort.DecrementCount();
1137 return true;
1138 });
1139 // In addition to the Close() call we expect in the test fixture above, expect
1140 // an additional one (the one that induced the abort).
1141 EXPECT_CALL(*mock_grpc_bidi_stream_, Close()).Times(1).RetiresOnSaturation();
1142 EXPECT_CALL(mock_log_manager_,
1143 LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC));
1144
1145 auto checkin_result = federated_protocol_->Checkin(
1146 GetFakeTaskEligibilityInfo(),
1147 mock_task_received_callback_.AsStdFunction());
1148 EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
1149 // RetryWindows were already received from the server during the eligibility
1150 // eval checkin, so we expect to get a 'rejected' retry window.
1151 ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1152 }
1153
// Verifies that a rejected regular checkin (issued with a TaskEligibilityInfo)
// produces a Rejection result, never invokes the 'task received' callback, and
// reports the rejected retry window.
TEST_P(GrpcFederatedProtocolTest, TestCheckinRejectionWithTaskEligibilityInfo) {
  // Start from a completed eligibility eval checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The upcoming Checkin(...) call sends one request and receives a rejection.
  const auto rejection_reply = GetFakeRejectedCheckinResponse();
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(rejection_reply),
                      Return(absl::OkStatus())));

  // No task is handed to the client, hence the 'task received' callback must
  // never fire.
  EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);

  // Run the regular checkin under test.
  const auto checkin_outcome = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome.status());
  EXPECT_THAT(*checkin_outcome, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1178
1179 // Tests whether we can issue a Checkin() request correctly without passing a
1180 // TaskEligibilityInfo, in the case that the eligibility eval checkin didn't
1181 // return any eligibility eval task to run.
TEST_P(GrpcFederatedProtocolTest,TestCheckinRejectionWithoutTaskEligibilityInfo)1182 TEST_P(GrpcFederatedProtocolTest,
1183 TestCheckinRejectionWithoutTaskEligibilityInfo) {
1184 // Issue an eligibility eval checkin first.
1185 ASSERT_OK(
1186 RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));
1187
1188 // Expect a checkin request for the next call to Checkin(...).
1189 EXPECT_CALL(*mock_grpc_bidi_stream_,
1190 Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
1191 /*task_eligibility_info=*/std::nullopt,
1192 enable_http_resource_support_)))))
1193 .WillOnce(Return(absl::OkStatus()));
1194 EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1195 .WillOnce(DoAll(SetArgPointee<0>(GetFakeRejectedCheckinResponse()),
1196 Return(absl::OkStatus())));
1197
1198 // The 'task received' callback should not be invoked since no task was given
1199 // to the client.
1200 EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);
1201
1202 // Issue the regular checkin, without a TaskEligibilityInfo (since we didn't
1203 // receive an eligibility eval task to run during eligibility eval checkin).
1204 auto checkin_result = federated_protocol_->Checkin(
1205 std::nullopt, mock_task_received_callback_.AsStdFunction());
1206
1207 ASSERT_OK(checkin_result.status());
1208 EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
1209 ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1210 }
1211
// Verifies the accepted regular checkin path with inline task resources: the
// outgoing request, the 'task received' callback, the returned task assignment
// and the resulting retry window are all checked.
TEST_P(GrpcFederatedProtocolTest, TestCheckinAccept) {
  // Start from a completed eligibility eval checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // After the eligibility eval checkin, make the stream report some fake
  // network stats.
  const int64_t fake_bytes_received = 555;
  const int64_t fake_bytes_sent = 666;
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
      .WillRepeatedly(Return(fake_bytes_received));
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
      .WillRepeatedly(Return(fake_bytes_sent));

  // Unlike the earlier tests, which accept any request via '_', this one pins
  // down the exact CheckinRequest that gets sent.
  const TaskEligibilityInfo eligibility_info = GetFakeTaskEligibilityInfo();
  EXPECT_CALL(*mock_grpc_bidi_stream_,
              Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
                  eligibility_info, enable_http_resource_support_)))))
      .WillOnce(Return(absl::OkStatus()));

  // The server accepts the client; plan and checkpoint are embedded in the
  // response and secure aggregation is requested.
  const std::string plan_data = kPlan;
  const std::string checkpoint_data = kInitCheckpoint;
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeAcceptedCheckinResponse(
                          plan_data, checkpoint_data,
                          kFederatedSelectUriTemplate, kExecutionPhaseId,
                          /* use_secure_aggregation=*/true)),
                      Return(absl::OkStatus())));

  // Matcher for the SecAggInfo expected both in the callback argument and in
  // the returned task assignment.
  const auto sec_agg_info_matcher = Optional(AllOf(
      Field(&FederatedProtocol::SecAggInfo::expected_number_of_clients,
            kSecAggExpectedNumberOfClients),
      Field(&FederatedProtocol::SecAggInfo::
                minimum_clients_in_server_visible_aggregate,
            kSecAggMinClientsInServerVisibleAggregate)));

  // Inline availability of the resources must not suppress the
  // 'task received' callback.
  EXPECT_CALL(mock_task_received_callback_,
              Call(FieldsAre(FieldsAre("", ""), kFederatedSelectUriTemplate,
                             kExecutionPhaseId, sec_agg_info_matcher)));

  // Run the regular checkin under test.
  const auto checkin_outcome = federated_protocol_->Checkin(
      eligibility_info, mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_outcome.status());
  if (!enable_http_resource_support_) {
    // Without HTTP support the payload is surfaced as plain strings.
    EXPECT_THAT(*checkin_outcome,
                VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
                    FieldsAre(plan_data, checkpoint_data),
                    kFederatedSelectUriTemplate, kExecutionPhaseId,
                    sec_agg_info_matcher)));
  } else {
    // With HTTP support enabled the payload is always surfaced as absl::Cord
    // values, whether or not anything was downloaded over HTTP.
    EXPECT_THAT(
        *checkin_outcome,
        VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
            FieldsAre(absl::Cord(plan_data), absl::Cord(checkpoint_data)),
            kFederatedSelectUriTemplate, kExecutionPhaseId,
            sec_agg_info_matcher)));
  }
  // An accepted checkin reports the accepted retry window that came with the
  // CheckinRequestAck response to the first eligibility eval request.
  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1296
// Verifies the accepted regular checkin path when both the plan and the
// checkpoint are referenced by URI and fetched over HTTP.
TEST_P(GrpcFederatedProtocolTest,
       TestCheckinAcceptWithHttpResourcesDownloaded) {
  if (!enable_http_resource_support_) {
    GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
                    "is enabled";
    return;
  }
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // Once the eligibility eval checkin has succeeded, make the stream report
  // some fake network stats.
  int64_t chunking_layer_bytes_downloaded = 555;
  int64_t chunking_layer_bytes_uploaded = 666;
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
      .WillRepeatedly(Return(chunking_layer_bytes_downloaded));
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
      .WillRepeatedly(Return(chunking_layer_bytes_uploaded));

  // This test checks the exact CheckinRequest that is sent, with HTTP resource
  // support enabled in the expected request.
  TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
  EXPECT_CALL(
      *mock_grpc_bidi_stream_,
      Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
          expected_eligibility_info, /*enable_http_resource_support=*/true)))))
      .WillOnce(Return(absl::OkStatus()));

  // Build an accepted response with empty inline data and URIs for both the
  // plan and the checkpoint.
  std::string expected_plan = kPlan;
  std::string plan_uri = "https://fake.uri/plan";
  std::string expected_checkpoint = kInitCheckpoint;
  std::string checkpoint_uri = "https://fake.uri/checkpoint";
  ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
      /*plan=*/"", /*init_checkpoint=*/"", kFederatedSelectUriTemplate,
      kExecutionPhaseId,
      /* use_secure_aggregation=*/true);
  AcceptanceInfo* acceptance_info =
      fake_checkin_response.mutable_checkin_response()
          ->mutable_acceptance_info();
  acceptance_info->mutable_plan_resource()->set_uri(plan_uri);
  acceptance_info->mutable_init_checkpoint_resource()->set_uri(checkpoint_uri);
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
                      Return(absl::OkStatus())));

  // Matcher for the SecAggInfo expected both in the callback argument and in
  // the returned task assignment.
  const auto sec_agg_info_matcher = Optional(AllOf(
      Field(&FederatedProtocol::SecAggInfo::expected_number_of_clients,
            kSecAggExpectedNumberOfClients),
      Field(&FederatedProtocol::SecAggInfo::
                minimum_clients_in_server_visible_aggregate,
            kSecAggMinClientsInServerVisibleAggregate)));

  {
    InSequence seq;
    // The 'task received' callback should be called *before* the actual task
    // resources are fetched.
    EXPECT_CALL(mock_task_received_callback_,
                Call(FieldsAre(FieldsAre("", ""), kFederatedSelectUriTemplate,
                               kExecutionPhaseId, sec_agg_info_matcher)));

    EXPECT_CALL(mock_http_client_,
                PerformSingleRequest(SimpleHttpRequestMatcher(
                    plan_uri, HttpRequest::Method::kGet, _, "")))
        .WillOnce(Return(FakeHttpResponse(200, {}, expected_plan)));

    EXPECT_CALL(mock_http_client_,
                PerformSingleRequest(SimpleHttpRequestMatcher(
                    checkpoint_uri, HttpRequest::Method::kGet, _, "")))
        .WillOnce(Return(FakeHttpResponse(200, {}, expected_checkpoint)));
  }

  {
    InSequence seq;
    // HTTP usage is logged first, then the successful fetch.
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::
                HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED));
  }

  // Issue the regular checkin.
  auto checkin_result = federated_protocol_->Checkin(
      expected_eligibility_info, mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result.status());
  // The downloaded bytes must show up in the returned task assignment.
  EXPECT_THAT(
      *checkin_result,
      VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
          FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
          kFederatedSelectUriTemplate, kExecutionPhaseId,
          sec_agg_info_matcher)));
  // The Checkin call is expected to return the accepted retry window from the
  // CheckinRequestAck response to the first eligibility eval request.
  ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1403
// Verifies that Checkin() fails when the server accepts the client but the
// plan, which the response only references via an HTTP resource URI, cannot be
// fetched: the HTTP 404 must surface as NOT_FOUND and the permanent error retry
// window must be reported.
TEST_P(GrpcFederatedProtocolTest,
       TestCheckinAcceptWithHttpResourcePlanDataFetchFailed) {
  if (!enable_http_resource_support_) {
    // GTEST_SKIP() returns from the test body by itself, so no explicit
    // `return` is needed after it.
    GTEST_SKIP() << "This test only applies if the HTTP task resources "
                    "feature is enabled";
  }
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The outgoing CheckinRequest is matched exactly against the request built
  // by GetExpectedCheckinRequest() with HTTP resource support enabled.
  TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
  EXPECT_CALL(
      *mock_grpc_bidi_stream_,
      Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
          expected_eligibility_info, /*enable_http_resource_support=*/true)))))
      .WillOnce(Return(absl::OkStatus()));

  // The fake response carries the init checkpoint but an empty plan; the plan
  // is only made available through a resource URI.
  std::string plan_uri = "https://fake.uri/plan";
  std::string expected_checkpoint = kInitCheckpoint;
  ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
      /*plan=*/"", expected_checkpoint, kFederatedSelectUriTemplate,
      kExecutionPhaseId,
      /* use_secure_aggregation=*/true);
  AcceptanceInfo* acceptance_info =
      fake_checkin_response.mutable_checkin_response()
          ->mutable_acceptance_info();
  acceptance_info->mutable_plan_resource()->set_uri(plan_uri);
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
                      Return(absl::OkStatus())));

  // Mock a failed plan fetch.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  plan_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(404, {}, "")));

  {
    // The "uses HTTP" diag code must be logged before the "fetch failed" one.
    InSequence seq;
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::
                HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
  }

  // Issue the regular checkin.
  auto checkin_result = federated_protocol_->Checkin(
      expected_eligibility_info, mock_task_received_callback_.AsStdFunction());

  EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
  EXPECT_THAT(checkin_result.status().message(),
              HasSubstr("plan fetch failed"));
  EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
  // The Checkin call is expected to return the permanent error retry window,
  // since 404 maps to a permanent error.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1470
// Verifies that Checkin() fails when the server accepts the client but the
// initial checkpoint, which the response only references via an HTTP resource
// URI, cannot be fetched: the HTTP 503 must surface as UNAVAILABLE and the
// rejected retry window must be reported.
TEST_P(GrpcFederatedProtocolTest,
       TestCheckinAcceptWithHttpResourceCheckpointDataFetchFailed) {
  if (!enable_http_resource_support_) {
    // GTEST_SKIP() returns from the test body by itself, so no explicit
    // `return` is needed after it.
    GTEST_SKIP() << "This test only applies if the HTTP task resources "
                    "feature is enabled";
  }
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // The outgoing CheckinRequest is matched exactly against the request built
  // by GetExpectedCheckinRequest() with HTTP resource support enabled.
  TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
  EXPECT_CALL(
      *mock_grpc_bidi_stream_,
      Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
          expected_eligibility_info, /*enable_http_resource_support=*/true)))))
      .WillOnce(Return(absl::OkStatus()));

  // The fake response carries the plan but an empty init checkpoint; the
  // checkpoint is only made available through a resource URI.
  std::string expected_plan = kPlan;
  std::string checkpoint_uri = "https://fake.uri/checkpoint";
  ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
      expected_plan, /*init_checkpoint=*/"", kFederatedSelectUriTemplate,
      kExecutionPhaseId,
      /* use_secure_aggregation=*/true);
  AcceptanceInfo* acceptance_info =
      fake_checkin_response.mutable_checkin_response()
          ->mutable_acceptance_info();
  acceptance_info->mutable_init_checkpoint_resource()->set_uri(checkpoint_uri);
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
                      Return(absl::OkStatus())));

  // Mock a failed checkpoint fetch.
  EXPECT_CALL(mock_http_client_,
              PerformSingleRequest(SimpleHttpRequestMatcher(
                  checkpoint_uri, HttpRequest::Method::kGet, _, "")))
      .WillOnce(Return(FakeHttpResponse(503, {}, "")));

  {
    // The "uses HTTP" diag code must be logged before the "fetch failed" one.
    InSequence seq;
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
    EXPECT_CALL(
        mock_log_manager_,
        LogDiag(
            ProdDiagCode::
                HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
  }

  // Issue the regular checkin.
  auto checkin_result = federated_protocol_->Checkin(
      expected_eligibility_info, mock_task_received_callback_.AsStdFunction());

  EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
  EXPECT_THAT(checkin_result.status().message(),
              HasSubstr("checkpoint fetch failed"));
  EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
  // The Checkin call is expected to return the rejected retry window from the
  // response to the first eligibility eval request.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1537
// An accepted regular checkin for a task that does not use SecAgg must yield a
// TaskAssignment carrying the plan, the initial checkpoint, the federated
// select URI template and the execution phase ID, but no SecAggInfo.
TEST_P(GrpcFederatedProtocolTest, TestCheckinAcceptNonSecAgg) {
  // Eligibility eval checkin first, then the regular (non-SecAgg) checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  auto task_assignment_or =
      RunSuccessfulCheckin(/*use_secure_aggregation=*/false);
  ASSERT_OK(task_assignment_or.status());

  if (!enable_http_resource_support_) {
    // Without HTTP resource support the payloads are compared as plain
    // strings.
    EXPECT_THAT(*task_assignment_or,
                VariantWith<FederatedProtocol::TaskAssignment>(
                    FieldsAre(FieldsAre(kPlan, kInitCheckpoint),
                              kFederatedSelectUriTemplate, kExecutionPhaseId,
                              // There should be no SecAggInfo in the result.
                              Eq(std::nullopt))));
    return;
  }

  // With HTTP resource support enabled the payloads come back as absl::Cord
  // instead of std::string, even when nothing was actually downloaded over
  // HTTP.
  EXPECT_THAT(*task_assignment_or,
              VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
                  FieldsAre(absl::Cord(kPlan), absl::Cord(kInitCheckpoint)),
                  kFederatedSelectUriTemplate, kExecutionPhaseId,
                  // There should be no SecAggInfo in the result.
                  Eq(std::nullopt))));
}
1563
// Reporting results for a SecAgg task: the protocol must create a SecAgg
// runner through the factory, pass it both the TF checkpoint and the quantized
// tensor, and finish successfully once the server's report response arrives.
TEST_P(GrpcFederatedProtocolTest, TestReportWithSecAgg) {
  // Eligibility eval checkin, then a successful checkin for a SecAgg task.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/true));

  // A SecAgg-style result set: one (empty) TF checkpoint plus one
  // default-constructed quantized aggregand.
  ComputationResults computation_results;
  computation_results.emplace("tensorflow_checkpoint", "");
  computation_results.emplace("some_tensor", QuantizedTensor());

  // Matchers describing the two entries the runner is expected to receive.
  const auto empty_checkpoint_entry =
      Pair("tensorflow_checkpoint", VariantWith<TFCheckpoint>(IsEmpty()));
  const auto empty_tensor_entry =
      Pair("some_tensor",
           VariantWith<QuantizedTensor>(FieldsAre(IsEmpty(), 0, IsEmpty())));

  // The server's reply to the report.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeReportResponse()),
                      Return(absl::OkStatus())));

  // The factory returns the mock runner. WrapUnique() takes ownership of the
  // raw pointer, while mock_secagg_runner_ stays behind as a non-owning alias
  // for setting expectations on the runner.
  mock_secagg_runner_ = new StrictMock<MockSecAggRunner>();
  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, kSecAggExpectedNumberOfClients,
                                 kSecAggMinSurvivingClientsForReconstruction))
      .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner_))));
  EXPECT_CALL(*mock_secagg_runner_,
              Run(UnorderedElementsAre(empty_checkpoint_entry,
                                       empty_tensor_entry)))
      .WillOnce(Return(absl::OkStatus()));

  EXPECT_OK(federated_protocol_->ReportCompleted(
      std::move(computation_results), absl::ZeroDuration(), std::nullopt));
}
1593
// Same as the SecAgg report test, except that the results contain only a
// quantized tensor and no TF checkpoint; the runner must then receive exactly
// that single entry.
TEST_P(GrpcFederatedProtocolTest, TestReportWithSecAggWithoutTFCheckpoint) {
  // Eligibility eval checkin, then a successful checkin for a SecAgg task.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/true));

  // Results without a "tensorflow_checkpoint" entry.
  ComputationResults computation_results;
  computation_results.emplace("some_tensor", QuantizedTensor());

  // Matcher for the only entry the runner is expected to receive.
  const auto empty_tensor_entry =
      Pair("some_tensor",
           VariantWith<QuantizedTensor>(FieldsAre(IsEmpty(), 0, IsEmpty())));

  // The server's reply to the report.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeReportResponse()),
                      Return(absl::OkStatus())));

  // The factory returns the mock runner. WrapUnique() takes ownership of the
  // raw pointer, while mock_secagg_runner_ stays behind as a non-owning alias
  // for setting expectations on the runner.
  mock_secagg_runner_ = new StrictMock<MockSecAggRunner>();
  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, kSecAggExpectedNumberOfClients,
                                 kSecAggMinSurvivingClientsForReconstruction))
      .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner_))));
  EXPECT_CALL(*mock_secagg_runner_,
              Run(UnorderedElementsAre(empty_tensor_entry)))
      .WillOnce(Return(absl::OkStatus()));

  EXPECT_OK(federated_protocol_->ReportCompleted(
      std::move(computation_results), absl::ZeroDuration(), std::nullopt));
}
1619
1620 // This function tests the Report(...) method's Send code path, ensuring the
1621 // right events are logged / and the right data is transmitted to the server.
TEST_P(GrpcFederatedProtocolTest,TestReportSendFails)1622 TEST_P(GrpcFederatedProtocolTest, TestReportSendFails) {
1623 // Issue an eligibility eval checkin first, followed by a successful checkin.
1624 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1625 ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1626
1627 // 1. Create input for the Report function.
1628 std::string checkpoint_str;
1629 const size_t kTFCheckpointSize = 32;
1630 checkpoint_str.resize(kTFCheckpointSize, 'X');
1631 ComputationResults results;
1632 results.emplace("tensorflow_checkpoint", checkpoint_str);
1633
1634 absl::Duration plan_duration = absl::Milliseconds(1337);
1635
1636 // 2. The expected message sent to the server by the ReportCompleted()
1637 // function, as text proto.
1638 ClientStreamMessage expected_client_stream_message;
1639 ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
1640 absl::StrCat(
1641 "report_request {", " population_name: \"", kPopulationName, "\"",
1642 " execution_phase_id: \"", kExecutionPhaseId, "\"", " report {",
1643 " update_checkpoint: \"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\"",
1644 " serialized_train_event {", "[type.googleapis.com/",
1645 "google.internal.federatedml.v2.ClientExecutionStats] {",
1646 " duration { seconds: 1 nanos: 337000000 }", " }",
1647 " }", " }", "}"),
1648 &expected_client_stream_message));
1649
1650 // 3. Set up mocks.
1651 EXPECT_CALL(*mock_grpc_bidi_stream_,
1652 Send(Pointee(EqualsProto(expected_client_stream_message))))
1653 .WillOnce(Return(absl::AbortedError("foo")));
1654
1655 // 4. Test that ReportCompleted() sends the expected message.
1656 auto report_result = federated_protocol_->ReportCompleted(
1657 std::move(results), plan_duration, std::nullopt);
1658 EXPECT_THAT(report_result, IsCode(ABORTED));
1659 EXPECT_THAT(report_result.message(), HasSubstr("foo"));
1660
1661 // If we made it to the Report protocol phase, then the client must've been
1662 // accepted during the Checkin phase first, and so we should receive the
1663 // "accepted" RetryWindow.
1664 ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1665 }
1666
1667 // This function tests the happy path of ReportCompleted() - results get
1668 // reported, server replies with a RetryWindow.
TEST_P(GrpcFederatedProtocolTest,TestPublishReportSuccess)1669 TEST_P(GrpcFederatedProtocolTest, TestPublishReportSuccess) {
1670 // Issue an eligibility eval checkin first, followed by a successful checkin.
1671 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1672 ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1673
1674 // 1. Create input for the Report function.
1675 ComputationResults results;
1676 results.emplace("tensorflow_checkpoint", "");
1677
1678 // 2. Set up mocks.
1679 EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
1680 .WillOnce(Return(absl::OkStatus()));
1681 ServerStreamMessage response_message;
1682 response_message.mutable_report_response();
1683 EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1684 .WillOnce(
1685 DoAll(SetArgPointee<0>(response_message), Return(absl::OkStatus())));
1686
1687 // 3. Test that ReportCompleted() sends the expected message.
1688 auto report_result = federated_protocol_->ReportCompleted(
1689 std::move(results), absl::ZeroDuration(), std::nullopt);
1690 EXPECT_OK(report_result);
1691
1692 // If we made it to the Report protocol phase, then the client must've been
1693 // accepted during the Checkin phase first, and so we should receive the
1694 // "accepted" RetryWindow.
1695 ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1696 }
1697
1698 // This function tests the Send code path when PhaseOutcome indicates an
1699 // error. / In that case, no checkpoint, and only the duration stat, should be
1700 // uploaded.
TEST_P(GrpcFederatedProtocolTest,TestPublishReportNotCompleteSendFails)1701 TEST_P(GrpcFederatedProtocolTest, TestPublishReportNotCompleteSendFails) {
1702 // Issue an eligibility eval checkin first, followed by a successful checkin.
1703 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1704 ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1705
1706 // 1. Create input for the Report function.
1707 absl::Duration plan_duration = absl::Milliseconds(1337);
1708
1709 // 2. The expected message sent to the server by the ReportNotCompleted()
1710 // function, as text proto.
1711 ClientStreamMessage expected_client_stream_message;
1712 ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
1713 absl::StrCat("report_request {", " population_name: \"", kPopulationName,
1714 "\"", " execution_phase_id: \"", kExecutionPhaseId, "\"",
1715 " report {", " serialized_train_event {",
1716 "[type.googleapis.com/",
1717 "google.internal.federatedml.v2.ClientExecutionStats] {",
1718 " duration { seconds: 1 nanos: 337000000 }",
1719 " }", " }", " status_code: INTERNAL", " }", "}"),
1720 &expected_client_stream_message));
1721
1722 // 3. Set up mocks.
1723 EXPECT_CALL(*mock_grpc_bidi_stream_,
1724 Send(Pointee(EqualsProto(expected_client_stream_message))))
1725 .WillOnce(Return(absl::AbortedError("foo")));
1726
1727 // 4. Test that ReportNotCompleted() sends the expected message.
1728 auto report_result = federated_protocol_->ReportNotCompleted(
1729 engine::PhaseOutcome::ERROR, plan_duration, std::nullopt);
1730 EXPECT_THAT(report_result, IsCode(ABORTED));
1731 EXPECT_THAT(report_result.message(), HasSubstr("foo"));
1732
1733 // If we made it to the Report protocol phase, then the client must've been
1734 // accepted during the Checkin phase first, and so we should receive the
1735 // "accepted" RetryWindow.
1736 ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1737 }
1738
1739 // This function tests the happy path of ReportCompleted() - results get
1740 // reported, server replies with a RetryWindow.
TEST_P(GrpcFederatedProtocolTest,TestPublishReportSuccessCommitsToOpstats)1741 TEST_P(GrpcFederatedProtocolTest, TestPublishReportSuccessCommitsToOpstats) {
1742 // Issue an eligibility eval checkin first, followed by a successful checkin.
1743 ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1744 ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1745
1746 // 1. Create input for the Report function.
1747 ComputationResults results;
1748 results.emplace("tensorflow_checkpoint", "");
1749
1750 // 2. Set up mocks.
1751 EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
1752 .WillOnce(Return(absl::OkStatus()));
1753 ServerStreamMessage response_message;
1754 response_message.mutable_report_response();
1755 EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1756 .WillOnce(
1757 DoAll(SetArgPointee<0>(response_message), Return(absl::OkStatus())));
1758
1759 // 3. Test that ReportCompleted() sends the expected message.
1760 auto report_result = federated_protocol_->ReportCompleted(
1761 std::move(results), absl::ZeroDuration(), std::nullopt);
1762 EXPECT_OK(report_result);
1763
1764 // If we made it to the Report protocol phase, then the client must've been
1765 // accepted during the Checkin phase first, and so we should receive the
1766 // "accepted" RetryWindow.
1767 ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1768 }
1769
1770 } // anonymous namespace
1771 } // namespace fcp::client
1772