xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_federated_protocol_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/grpc_federated_protocol.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23 
24 #include "google/protobuf/text_format.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/random/random.h"
28 #include "absl/status/status.h"
29 #include "absl/synchronization/blocking_counter.h"
30 #include "absl/time/time.h"
31 #include "fcp/base/monitoring.h"
32 #include "fcp/client/cache/test_helpers.h"
33 #include "fcp/client/diag_codes.pb.h"
34 #include "fcp/client/engine/engine.pb.h"
35 #include "fcp/client/grpc_bidi_stream.h"
36 #include "fcp/client/http/http_client.h"
37 #include "fcp/client/http/testing/test_helpers.h"
38 #include "fcp/client/interruptible_runner.h"
39 #include "fcp/client/stats.h"
40 #include "fcp/client/test_helpers.h"
41 #include "fcp/protos/federated_api.pb.h"
42 #include "fcp/secagg/client/secagg_client.h"
43 #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
44 #include "fcp/secagg/testing/fake_prng.h"
45 #include "fcp/secagg/testing/mock_send_to_server_interface.h"
46 #include "fcp/secagg/testing/mock_state_transition_listener.h"
47 #include "fcp/testing/testing.h"
48 
49 namespace fcp::client {
50 namespace {
51 
52 using ::fcp::EqualsProto;
53 using ::fcp::IsCode;
54 using ::fcp::client::http::FakeHttpResponse;
55 using ::fcp::client::http::HttpRequest;
56 using ::fcp::client::http::MockHttpClient;
57 using ::fcp::client::http::SimpleHttpRequestMatcher;
58 using ::google::internal::federatedml::v2::AcceptanceInfo;
59 using ::google::internal::federatedml::v2::CheckinRequest;
60 using ::google::internal::federatedml::v2::ClientStreamMessage;
61 using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
62 using ::google::internal::federatedml::v2::EligibilityEvalPayload;
63 using ::google::internal::federatedml::v2::HttpCompressionFormat;
64 using ::google::internal::federatedml::v2::ReportResponse;
65 using ::google::internal::federatedml::v2::RetryWindow;
66 using ::google::internal::federatedml::v2::ServerStreamMessage;
67 using ::google::internal::federatedml::v2::TaskEligibilityInfo;
68 using ::google::internal::federatedml::v2::TaskWeight;
69 using ::testing::_;
70 using ::testing::AllOf;
71 using ::testing::DoAll;
72 using ::testing::DoubleEq;
73 using ::testing::DoubleNear;
74 using ::testing::Eq;
75 using ::testing::Field;
76 using ::testing::FieldsAre;
77 using ::testing::Ge;
78 using ::testing::Gt;
79 using ::testing::HasSubstr;
80 using ::testing::InSequence;
81 using ::testing::IsEmpty;
82 using ::testing::Lt;
83 using ::testing::MockFunction;
84 using ::testing::NiceMock;
85 using ::testing::Not;
86 using ::testing::Optional;
87 using ::testing::Pair;
88 using ::testing::Pointee;
89 using ::testing::Return;
90 using ::testing::SetArgPointee;
91 using ::testing::StrictMock;
92 using ::testing::UnorderedElementsAre;
93 using ::testing::VariantWith;
94 
// Values the fixture passes to the GrpcFederatedProtocol constructor and
// which the expected check-in requests must therefore carry.
constexpr char kPopulationName[] = "TEST/POPULATION";
constexpr char kRetryToken[] = "OLD_RETRY_TOKEN";
constexpr char kClientVersion[] = "CLIENT_VERSION";
constexpr char kAttestationMeasurement[] = "ATTESTATION_MEASUREMENT";

// Payload values placed into the fake server responses built below.
constexpr char kPlan[] = "CLIENT_ONLY_PLAN";
constexpr char kInitCheckpoint[] = "INIT_CHECKPOINT";
constexpr char kFederatedSelectUriTemplate[] = "https://federated.select";
constexpr char kExecutionPhaseId[] = "TEST/POPULATION/TEST_TASK#1234.ab35";

// Secure aggregation parameters placed into the fake accepted check-in
// response when secure aggregation is requested.
constexpr int kSecAggExpectedNumberOfClients = 10;
constexpr int kSecAggMinSurvivingClientsForReconstruction = 8;
constexpr int kSecAggMinClientsInServerVisibleAggregate = 4;
106 
107 class MockGrpcBidiStream : public GrpcBidiStreamInterface {
108  public:
109   MOCK_METHOD(absl::Status, Send, (ClientStreamMessage*), (override));
110   MOCK_METHOD(absl::Status, Receive, (ServerStreamMessage*), (override));
111   MOCK_METHOD(void, Close, (), (override));
112   MOCK_METHOD(int64_t, ChunkingLayerBytesSent, (), (override));
113   MOCK_METHOD(int64_t, ChunkingLayerBytesReceived, (), (override));
114 };
115 
// Transient-error retry configuration returned by the mocked flags, and the
// delay range (in seconds) the tests accept for a retry window derived from it.
constexpr int kTransientErrorsRetryPeriodSecs = 10;
constexpr double kTransientErrorsRetryDelayJitterPercent = 0.1;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMin = 9.0;
constexpr double kExpectedTransientErrorsRetryPeriodSecsMax = 11.0;

// Permanent-error retry configuration returned by the mocked flags, and the
// delay range (in seconds) the tests accept for a retry window derived from it.
constexpr int kPermanentErrorsRetryPeriodSecs = 100;
constexpr double kPermanentErrorsRetryDelayJitterPercent = 0.2;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMin = 80.0;
constexpr double kExpectedPermanentErrorsRetryPeriodSecsMax = 120.0;
124 
ExpectTransientErrorRetryWindow(const RetryWindow & retry_window)125 void ExpectTransientErrorRetryWindow(const RetryWindow& retry_window) {
126   // The calculated retry delay must lie within the expected transient errors
127   // retry delay range.
128   EXPECT_THAT(retry_window.delay_min().seconds() +
129                   retry_window.delay_min().nanos() / 1000000000,
130               AllOf(Ge(kExpectedTransientErrorsRetryPeriodSecsMin),
131                     Lt(kExpectedTransientErrorsRetryPeriodSecsMax)));
132   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
133 }
134 
ExpectPermanentErrorRetryWindow(const RetryWindow & retry_window)135 void ExpectPermanentErrorRetryWindow(const RetryWindow& retry_window) {
136   // The calculated retry delay must lie within the expected permanent errors
137   // retry delay range.
138   EXPECT_THAT(retry_window.delay_min().seconds() +
139                   retry_window.delay_min().nanos() / 1000000000,
140               AllOf(Ge(kExpectedPermanentErrorsRetryPeriodSecsMin),
141                     Lt(kExpectedPermanentErrorsRetryPeriodSecsMax)));
142   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
143 }
144 
GetAcceptedRetryWindow()145 google::internal::federatedml::v2::RetryWindow GetAcceptedRetryWindow() {
146   google::internal::federatedml::v2::RetryWindow retry_window;
147   // Must not overlap with kTransientErrorsRetryPeriodSecs or
148   // kPermanentErrorsRetryPeriodSecs.
149   retry_window.mutable_delay_min()->set_seconds(200L);
150   retry_window.mutable_delay_max()->set_seconds(299L);
151   *retry_window.mutable_retry_token() = "RETRY_TOKEN_ACCEPTED";
152   return retry_window;
153 }
154 
GetRejectedRetryWindow()155 google::internal::federatedml::v2::RetryWindow GetRejectedRetryWindow() {
156   google::internal::federatedml::v2::RetryWindow retry_window;
157   // Must not overlap with kTransientErrorsRetryPeriodSecs or
158   // kPermanentErrorsRetryPeriodSecs.
159   retry_window.mutable_delay_min()->set_seconds(300);
160   retry_window.mutable_delay_max()->set_seconds(399L);
161   *retry_window.mutable_retry_token() = "RETRY_TOKEN_REJECTED";
162   return retry_window;
163 }
164 
ExpectAcceptedRetryWindow(const RetryWindow & retry_window)165 void ExpectAcceptedRetryWindow(const RetryWindow& retry_window) {
166   // The calculated retry delay must lie within the expected permanent errors
167   // retry delay range.
168   EXPECT_THAT(retry_window.delay_min().seconds() +
169                   retry_window.delay_min().nanos() / 1000000000,
170               AllOf(Ge(200), Lt(299L)));
171   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
172 }
173 
ExpectRejectedRetryWindow(const RetryWindow & retry_window)174 void ExpectRejectedRetryWindow(const RetryWindow& retry_window) {
175   // The calculated retry delay must lie within the expected permanent errors
176   // retry delay range.
177   EXPECT_THAT(retry_window.delay_min().seconds() +
178                   retry_window.delay_min().nanos() / 1000000000,
179               AllOf(Ge(300), Lt(399)));
180   EXPECT_THAT(retry_window.delay_max(), EqualsProto(retry_window.delay_min()));
181 }
182 
GetFakeCheckinRequestAck(const RetryWindow & accepted_retry_window=GetAcceptedRetryWindow (),const RetryWindow & rejected_retry_window=GetRejectedRetryWindow ())183 ServerStreamMessage GetFakeCheckinRequestAck(
184     const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
185     const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
186   ServerStreamMessage checkin_request_ack_message;
187   *checkin_request_ack_message.mutable_checkin_request_ack()
188        ->mutable_retry_window_if_accepted() = accepted_retry_window;
189   *checkin_request_ack_message.mutable_checkin_request_ack()
190        ->mutable_retry_window_if_rejected() = rejected_retry_window;
191   return checkin_request_ack_message;
192 }
193 
GetFakeEnabledEligibilityCheckinResponse(const std::string & plan,const std::string & init_checkpoint,const std::string & execution_id)194 ServerStreamMessage GetFakeEnabledEligibilityCheckinResponse(
195     const std::string& plan, const std::string& init_checkpoint,
196     const std::string& execution_id) {
197   ServerStreamMessage checkin_response_message;
198   EligibilityEvalPayload* eval_payload =
199       checkin_response_message.mutable_eligibility_eval_checkin_response()
200           ->mutable_eligibility_eval_payload();
201   eval_payload->set_plan(plan);
202   eval_payload->set_init_checkpoint(init_checkpoint);
203   eval_payload->set_execution_id(execution_id);
204   return checkin_response_message;
205 }
206 
GetFakeDisabledEligibilityCheckinResponse()207 ServerStreamMessage GetFakeDisabledEligibilityCheckinResponse() {
208   ServerStreamMessage checkin_response_message;
209   checkin_response_message.mutable_eligibility_eval_checkin_response()
210       ->mutable_no_eligibility_eval_configured();
211   return checkin_response_message;
212 }
213 
GetFakeRejectedEligibilityCheckinResponse()214 ServerStreamMessage GetFakeRejectedEligibilityCheckinResponse() {
215   ServerStreamMessage rejection_response_message;
216   rejection_response_message.mutable_eligibility_eval_checkin_response()
217       ->mutable_rejection_info();
218   return rejection_response_message;
219 }
220 
GetFakeTaskEligibilityInfo()221 TaskEligibilityInfo GetFakeTaskEligibilityInfo() {
222   TaskEligibilityInfo eligibility_info;
223   TaskWeight* task_weight = eligibility_info.mutable_task_weights()->Add();
224   task_weight->set_task_name("foo");
225   task_weight->set_weight(567.8);
226   return eligibility_info;
227 }
228 
GetFakeRejectedCheckinResponse()229 ServerStreamMessage GetFakeRejectedCheckinResponse() {
230   ServerStreamMessage rejection_response_message;
231   rejection_response_message.mutable_checkin_response()
232       ->mutable_rejection_info();
233   return rejection_response_message;
234 }
235 
GetFakeAcceptedCheckinResponse(const std::string & plan,const std::string & init_checkpoint,const std::string & federated_select_uri_template,const std::string & phase_id,bool use_secure_aggregation)236 ServerStreamMessage GetFakeAcceptedCheckinResponse(
237     const std::string& plan, const std::string& init_checkpoint,
238     const std::string& federated_select_uri_template,
239     const std::string& phase_id, bool use_secure_aggregation) {
240   ServerStreamMessage checkin_response_message;
241   AcceptanceInfo* acceptance_info =
242       checkin_response_message.mutable_checkin_response()
243           ->mutable_acceptance_info();
244   acceptance_info->set_plan(plan);
245   acceptance_info->set_execution_phase_id(phase_id);
246   acceptance_info->set_init_checkpoint(init_checkpoint);
247   acceptance_info->mutable_federated_select_uri_info()->set_uri_template(
248       federated_select_uri_template);
249   if (use_secure_aggregation) {
250     auto sec_agg =
251         acceptance_info->mutable_side_channel_protocol_execution_info()
252             ->mutable_secure_aggregation();
253     sec_agg->set_expected_number_of_clients(kSecAggExpectedNumberOfClients);
254     sec_agg->set_minimum_surviving_clients_for_reconstruction(
255         kSecAggMinSurvivingClientsForReconstruction);
256     sec_agg->set_minimum_clients_in_server_visible_aggregate(
257         kSecAggMinClientsInServerVisibleAggregate);
258     checkin_response_message.mutable_checkin_response()
259         ->mutable_protocol_options_response()
260         ->mutable_side_channels()
261         ->mutable_secure_aggregation()
262         ->set_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
263   }
264   return checkin_response_message;
265 }
266 
GetFakeReportResponse()267 ServerStreamMessage GetFakeReportResponse() {
268   ServerStreamMessage report_response_message;
269   *report_response_message.mutable_report_response() = ReportResponse();
270   return report_response_message;
271 }
272 
GetExpectedEligibilityEvalCheckinRequest(bool enable_http_resource_support=false)273 ClientStreamMessage GetExpectedEligibilityEvalCheckinRequest(
274     bool enable_http_resource_support = false) {
275   ClientStreamMessage expected_message;
276   EligibilityEvalCheckinRequest* checkin_request =
277       expected_message.mutable_eligibility_eval_checkin_request();
278   checkin_request->set_population_name(kPopulationName);
279   checkin_request->set_client_version(kClientVersion);
280   checkin_request->set_retry_token(kRetryToken);
281   checkin_request->set_attestation_measurement(kAttestationMeasurement);
282   checkin_request->mutable_protocol_options_request()
283       ->mutable_side_channels()
284       ->mutable_secure_aggregation()
285       ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
286   checkin_request->mutable_protocol_options_request()->set_should_ack_checkin(
287       true);
288   checkin_request->mutable_protocol_options_request()
289       ->add_supported_http_compression_formats(
290           HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
291 
292   if (enable_http_resource_support) {
293     checkin_request->mutable_protocol_options_request()
294         ->set_supports_http_download(true);
295     checkin_request->mutable_protocol_options_request()
296         ->set_supports_eligibility_eval_http_download(true);
297   }
298 
299   return expected_message;
300 }
301 
302 // This returns the CheckinRequest gRPC proto we expect each Checkin(...) call
303 // to result in.
GetExpectedCheckinRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info=std::nullopt,bool enable_http_resource_support=false)304 ClientStreamMessage GetExpectedCheckinRequest(
305     const std::optional<TaskEligibilityInfo>& task_eligibility_info =
306         std::nullopt,
307     bool enable_http_resource_support = false) {
308   ClientStreamMessage expected_message;
309   CheckinRequest* checkin_request = expected_message.mutable_checkin_request();
310   checkin_request->set_population_name(kPopulationName);
311   checkin_request->set_client_version(kClientVersion);
312   checkin_request->set_retry_token(kRetryToken);
313   checkin_request->set_attestation_measurement(kAttestationMeasurement);
314   checkin_request->mutable_protocol_options_request()
315       ->mutable_side_channels()
316       ->mutable_secure_aggregation()
317       ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
318   checkin_request->mutable_protocol_options_request()->set_should_ack_checkin(
319       false);
320   checkin_request->mutable_protocol_options_request()
321       ->add_supported_http_compression_formats(
322           HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
323 
324   if (enable_http_resource_support) {
325     checkin_request->mutable_protocol_options_request()
326         ->set_supports_http_download(true);
327     checkin_request->mutable_protocol_options_request()
328         ->set_supports_eligibility_eval_http_download(true);
329   }
330 
331   if (task_eligibility_info.has_value()) {
332     *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
333   }
334   return expected_message;
335 }
336 
337 class GrpcFederatedProtocolTest
338     // The first parameter indicates whether support for HTTP task resources
339     // should be enabled.
340     : public testing::TestWithParam<bool> {
341  public:
GrpcFederatedProtocolTest()342   GrpcFederatedProtocolTest() {
343     // The gRPC stream should always be closed at the end of all tests.
344     EXPECT_CALL(*mock_grpc_bidi_stream_, Close());
345   }
346 
347  protected:
SetUp()348   void SetUp() override {
349     enable_http_resource_support_ = GetParam();
350     EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
351         .WillRepeatedly(Return(0));
352     EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
353         .WillRepeatedly(Return(0));
354     EXPECT_CALL(mock_flags_,
355                 federated_training_transient_errors_retry_delay_secs)
356         .WillRepeatedly(Return(kTransientErrorsRetryPeriodSecs));
357     EXPECT_CALL(mock_flags_,
358                 federated_training_transient_errors_retry_delay_jitter_percent)
359         .WillRepeatedly(Return(kTransientErrorsRetryDelayJitterPercent));
360     EXPECT_CALL(mock_flags_,
361                 federated_training_permanent_errors_retry_delay_secs)
362         .WillRepeatedly(Return(kPermanentErrorsRetryPeriodSecs));
363     EXPECT_CALL(mock_flags_,
364                 federated_training_permanent_errors_retry_delay_jitter_percent)
365         .WillRepeatedly(Return(kPermanentErrorsRetryDelayJitterPercent));
366     EXPECT_CALL(mock_flags_, federated_training_permanent_error_codes)
367         .WillRepeatedly(Return(std::vector<int32_t>{
368             static_cast<int32_t>(absl::StatusCode::kNotFound),
369             static_cast<int32_t>(absl::StatusCode::kInvalidArgument),
370             static_cast<int32_t>(absl::StatusCode::kUnimplemented)}));
371     EXPECT_CALL(mock_flags_,
372                 enable_grpc_with_eligibility_eval_http_resource_support)
373         .WillRepeatedly(Return(enable_http_resource_support_));
374 
375     // We only initialize federated_protocol_ in this SetUp method, rather than
376     // in the test's constructor, to ensure that we can set mock flag values
377     // before the GrpcFederatedProtocol constructor is called. Using
378     // std::unique_ptr conveniently allows us to assign the field a new value
379     // after construction (which we could not do if the field's type was
380     // GrpcFederatedProtocol, since it doesn't have copy or move constructors).
381     federated_protocol_ = std::make_unique<GrpcFederatedProtocol>(
382         &mock_event_publisher_, &mock_log_manager_,
383         absl::WrapUnique(mock_secagg_runner_factory_), &mock_flags_,
384         /*http_client=*/
385         enable_http_resource_support_ ? &mock_http_client_ : nullptr,
386         // We want to inject mocks stored in unique_ptrs to the
387         // class-under-test, hence we transfer ownership via WrapUnique. To
388         // write expectations for the mock, we retain the raw pointer to it,
389         // which will be valid until GrpcFederatedProtocol's d'tor is called.
390         absl::WrapUnique(mock_grpc_bidi_stream_), kPopulationName, kRetryToken,
391         kClientVersion, kAttestationMeasurement,
392         mock_should_abort_.AsStdFunction(), absl::BitGen(),
393         InterruptibleRunner::TimingConfig{
394             .polling_period = absl::ZeroDuration(),
395             .graceful_shutdown_period = absl::InfiniteDuration(),
396             .extended_shutdown_period = absl::InfiniteDuration()},
397         &mock_resource_cache_);
398   }
399 
TearDown()400   void TearDown() override {
401     fcp::client::http::HttpRequestHandle::SentReceivedBytes
402         sent_received_bytes = mock_http_client_.TotalSentReceivedBytes();
403 
404     NetworkStats network_stats = federated_protocol_->GetNetworkStats();
405     EXPECT_THAT(network_stats.bytes_downloaded,
406                 Ge(mock_grpc_bidi_stream_->ChunkingLayerBytesReceived() +
407                    sent_received_bytes.received_bytes));
408     EXPECT_THAT(network_stats.bytes_uploaded,
409                 Ge(mock_grpc_bidi_stream_->ChunkingLayerBytesSent() +
410                    sent_received_bytes.sent_bytes));
411     // If any network traffic occurred, we expect to see some time reflected in
412     // the duration (if the flag is on).
413     if (network_stats.bytes_uploaded > 0) {
414       EXPECT_THAT(network_stats.network_duration, Gt(absl::ZeroDuration()));
415     }
416   }
417 
418   // This function runs a successful
419   // EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction()) that
420   // results in an eligibility eval payload being returned by the server. This
421   // is a utility function used by Checkin*() tests that depend on a prior,
422   // successful execution of
423   // EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction()). It
424   // returns a absl::Status, which the caller should verify is OK using
425   // ASSERT_OK.
RunSuccessfulEligibilityEvalCheckin(bool eligibility_eval_enabled=true,const RetryWindow & accepted_retry_window=GetAcceptedRetryWindow (),const RetryWindow & rejected_retry_window=GetRejectedRetryWindow ())426   absl::Status RunSuccessfulEligibilityEvalCheckin(
427       bool eligibility_eval_enabled = true,
428       const RetryWindow& accepted_retry_window = GetAcceptedRetryWindow(),
429       const RetryWindow& rejected_retry_window = GetRejectedRetryWindow()) {
430     EXPECT_CALL(
431         *mock_grpc_bidi_stream_,
432         Send(Pointee(EqualsProto(GetExpectedEligibilityEvalCheckinRequest(
433             enable_http_resource_support_)))))
434         .WillOnce(Return(absl::OkStatus()));
435 
436     const std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
437     EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
438         .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck(
439                             accepted_retry_window, rejected_retry_window)),
440                         Return(absl::OkStatus())))
441         .WillOnce(
442             DoAll(SetArgPointee<0>(
443                       eligibility_eval_enabled
444                           ? GetFakeEnabledEligibilityCheckinResponse(
445                                 kPlan, kInitCheckpoint, expected_execution_id)
446                           : GetFakeDisabledEligibilityCheckinResponse()),
447                   Return(absl::OkStatus())));
448 
449     return federated_protocol_
450         ->EligibilityEvalCheckin(mock_eet_received_callback_.AsStdFunction())
451         .status();
452   }
453 
454   // This function runs a successful Checkin() that results in acceptance by the
455   // server. This is a utility function used by Report*() tests that depend on a
456   // prior, successful execution of Checkin().
457   // It returns a absl::Status, which the caller should verify is OK using
458   // ASSERT_OK.
RunSuccessfulCheckin(bool use_secure_aggregation,const std::optional<TaskEligibilityInfo> & task_eligibility_info=GetFakeTaskEligibilityInfo ())459   absl::StatusOr<FederatedProtocol::CheckinResult> RunSuccessfulCheckin(
460       bool use_secure_aggregation,
461       const std::optional<TaskEligibilityInfo>& task_eligibility_info =
462           GetFakeTaskEligibilityInfo()) {
463     EXPECT_CALL(*mock_grpc_bidi_stream_,
464                 Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
465                     task_eligibility_info, enable_http_resource_support_)))))
466         .WillOnce(Return(absl::OkStatus()));
467 
468     {
469       InSequence seq;
470       EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
471           .WillOnce(
472               DoAll(SetArgPointee<0>(GetFakeAcceptedCheckinResponse(
473                         kPlan, kInitCheckpoint, kFederatedSelectUriTemplate,
474                         kExecutionPhaseId, use_secure_aggregation)),
475                     Return(absl::OkStatus())))
476           .RetiresOnSaturation();
477     }
478 
479     return federated_protocol_->Checkin(
480         task_eligibility_info, mock_task_received_callback_.AsStdFunction());
481   }
482 
483   // See note in the constructor for why these are pointers.
484   StrictMock<MockGrpcBidiStream>* mock_grpc_bidi_stream_ =
485       new StrictMock<MockGrpcBidiStream>();
486 
487   StrictMock<MockEventPublisher> mock_event_publisher_;
488   NiceMock<MockLogManager> mock_log_manager_;
489   StrictMock<MockSecAggRunnerFactory>* mock_secagg_runner_factory_ =
490       new StrictMock<MockSecAggRunnerFactory>();
491   StrictMock<MockSecAggRunner>* mock_secagg_runner_;
492   NiceMock<MockFlags> mock_flags_;
493   StrictMock<MockHttpClient> mock_http_client_;
494   NiceMock<MockFunction<bool()>> mock_should_abort_;
495   StrictMock<cache::MockResourceCache> mock_resource_cache_;
496   NiceMock<MockFunction<void(
497       const ::fcp::client::FederatedProtocol::EligibilityEvalTask&)>>
498       mock_eet_received_callback_;
499   NiceMock<MockFunction<void(
500       const ::fcp::client::FederatedProtocol::TaskAssignment&)>>
501       mock_task_received_callback_;
502 
503   // The class under test.
504   std::unique_ptr<GrpcFederatedProtocol> federated_protocol_;
505   bool enable_http_resource_support_;
506 };
507 
GenerateTestName(const testing::TestParamInfo<GrpcFederatedProtocolTest::ParamType> & info)508 std::string GenerateTestName(
509     const testing::TestParamInfo<GrpcFederatedProtocolTest::ParamType>& info) {
510   std::string name = info.param ? "Http_resource_support_enabled"
511                                 : "Http_resource_support_disabled";
512   return name;
513 }
514 
515 INSTANTIATE_TEST_SUITE_P(NewVsOldBehavior, GrpcFederatedProtocolTest,
516                          testing::Bool(), GenerateTestName);
517 
518 using GrpcFederatedProtocolDeathTest = GrpcFederatedProtocolTest;
519 INSTANTIATE_TEST_SUITE_P(NewVsOldBehavior, GrpcFederatedProtocolDeathTest,
520                          testing::Bool(), GenerateTestName);
521 
// Verifies that two separately constructed GrpcFederatedProtocol instances
// report different initial (transient-error based) retry windows.
TEST_P(GrpcFederatedProtocolTest,
       TestTransientErrorRetryWindowDifferentAcrossDifferentInstances) {
  // Hold a copy rather than a reference: the protocol instance that produced
  // this window is destroyed right below, and the value is still needed for
  // the final comparison.
  const RetryWindow retry_window1 = federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(retry_window1);
  federated_protocol_.reset(nullptr);

  // The previous stream and SecAgg runner factory mocks were owned by the
  // instance we just destroyed, so fresh ones are needed (with the same
  // baseline expectations the fixture sets up).
  mock_grpc_bidi_stream_ = new StrictMock<MockGrpcBidiStream>();
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
      .WillRepeatedly(Return(0));
  EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
      .WillRepeatedly(Return(0));
  EXPECT_CALL(*mock_grpc_bidi_stream_, Close());
  mock_secagg_runner_factory_ = new StrictMock<MockSecAggRunnerFactory>();
  // Create a new GrpcFederatedProtocol instance. It should not produce the same
  // retry window value as the one we just got. This is a simple correctness
  // check to ensure that the value is at least randomly generated (and that we
  // don't accidentally use the random number generator incorrectly).
  federated_protocol_ = std::make_unique<GrpcFederatedProtocol>(
      &mock_event_publisher_, &mock_log_manager_,
      absl::WrapUnique(mock_secagg_runner_factory_), &mock_flags_,
      /*http_client=*/nullptr, absl::WrapUnique(mock_grpc_bidi_stream_),
      kPopulationName, kRetryToken, kClientVersion, kAttestationMeasurement,
      mock_should_abort_.AsStdFunction(), absl::BitGen(),
      InterruptibleRunner::TimingConfig{
          .polling_period = absl::ZeroDuration(),
          .graceful_shutdown_period = absl::InfiniteDuration(),
          .extended_shutdown_period = absl::InfiniteDuration()},
      &mock_resource_cache_);

  const RetryWindow retry_window2 = federated_protocol_->GetLatestRetryWindow();
  ExpectTransientErrorRetryWindow(retry_window2);

  EXPECT_THAT(retry_window1, Not(EqualsProto(retry_window2)));
}
558 
// A Send() failure with a code that is not configured as permanent must be
// surfaced to the caller and result in a transient-error retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinSendFailsTransientError) {
  // The very first message EligibilityEvalCheckin(...) tries to send fails
  // with UNAVAILABLE; that error should be returned as the result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::UnavailableError("foo")));

  const auto result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status status = result.status();

  EXPECT_THAT(status, IsCode(UNAVAILABLE));
  EXPECT_THAT(status.message(), "foo");
  // The server never sent any RetryWindows, so the reported window must be
  // the one generated from the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
576 
// A Send() failure with a code that is configured as permanent must be
// surfaced to the caller and result in a permanent-error retry window.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinSendFailsPermanentError) {
  // The very first message EligibilityEvalCheckin(...) tries to send fails
  // with NOT_FOUND; that error should be returned as the result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::NotFoundError("foo")));

  const auto result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());
  const absl::Status status = result.status();

  EXPECT_THAT(status, IsCode(NOT_FOUND));
  EXPECT_THAT(status.message(), "foo");
  // The server never sent any RetryWindows, so the reported window must be
  // the one generated from the *permanent* errors retry delay flag, because
  // the flags mark NOT_FOUND as a permanent error.
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
595 
596 // Tests the case where the blocking Send() call in EligibilityEvalCheckin is
597 // interrupted.
TEST_P(GrpcFederatedProtocolTest,TestEligibilityEvalCheckinSendInterrupted)598 TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinSendInterrupted) {
599   absl::BlockingCounter counter_should_abort(1);
600 
601   // Make Send() block until the counter is decremented.
602   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
603       .WillOnce([&counter_should_abort](ClientStreamMessage* ignored) {
604         counter_should_abort.Wait();
605         return absl::OkStatus();
606       });
607   // Make should_abort return false for the first two calls, and then make it
608   // decrement the counter and return true, triggering an abort sequence and
609   // unblocking the Send() call we caused to block above.
610   EXPECT_CALL(mock_should_abort_, Call())
611       .WillOnce(Return(false))
612       .WillOnce(Return(false))
613       .WillRepeatedly([&counter_should_abort] {
614         counter_should_abort.DecrementCount();
615         return true;
616       });
617   // In addition to the Close() call we expect in the test fixture above, expect
618   // an additional one (the one that induced the abort).
619   EXPECT_CALL(*mock_grpc_bidi_stream_, Close()).Times(1).RetiresOnSaturation();
620   EXPECT_CALL(mock_log_manager_,
621               LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC));
622 
623   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
624       mock_eet_received_callback_.AsStdFunction());
625 
626   EXPECT_THAT(eligibility_checkin_result.status(), IsCode(CANCELLED));
627   // No RetryWindows were received from the server, so we expect to get a
628   // RetryWindow generated based on the transient errors retry delay flag.
629   ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
630 }
631 
632 // If a CheckinRequestAck is requested in the ProtocolOptionsRequest but not
633 // received, UNIMPLEMENTED should be returned.
TEST_P(GrpcFederatedProtocolTest,TestEligibilityEvalCheckinMissingCheckinRequestAck)634 TEST_P(GrpcFederatedProtocolTest,
635        TestEligibilityEvalCheckinMissingCheckinRequestAck) {
636   // We immediately return an EligibilityEvalCheckinResponse, rather than
637   // returning a CheckinRequestAck first.
638   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
639       .WillOnce(Return(absl::OkStatus()));
640   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
641       .WillOnce(
642           DoAll(SetArgPointee<0>(GetFakeRejectedEligibilityCheckinResponse()),
643                 Return(absl::OkStatus())));
644   EXPECT_CALL(
645       mock_log_manager_,
646       LogDiag(
647           ProdDiagCode::
648               BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD));  // NOLINT
649 
650   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
651       mock_eet_received_callback_.AsStdFunction());
652 
653   EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNIMPLEMENTED));
654   // No RetryWindows were received from the server, so we expect to get a
655   // RetryWindow generated based on the *permanent* errors retry delay flag,
656   // since UNIMPLEMENTED is marked as a permanent error in the flags.
657   ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
658 }
659 
// Tests that an error returned by the first Receive() call (the one expecting
// the CheckinRequestAck) is passed through unchanged, code and message, as the
// result of EligibilityEvalCheckin.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinWaitForCheckinRequestAckFails) {
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));

  // Make the very first Receive() call fail (i.e. the one expecting the
  // CheckinRequestAck).
  std::string expected_message = "foo";
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(Return(absl::AbortedError(expected_message)));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(eligibility_checkin_result.status(), IsCode(ABORTED));
  EXPECT_THAT(eligibility_checkin_result.status().message(), expected_message);
  // No RetryWindows were received from the server, so we expect to get a
  // RetryWindow generated based on the transient errors retry delay flag.
  ExpectTransientErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
680 
// Tests that an error returned by the second Receive() call (the one expecting
// the EligibilityEvalCheckinResponse, after the ack was already delivered) is
// passed through as the result, and that the rejected retry window is used.
TEST_P(GrpcFederatedProtocolTest,
       TestEligibilityEvalCheckinWaitForCheckinResponseFails) {
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));

  // Failed checkins that have received an ack already should return the
  // rejected retry window.
  std::string expected_message = "foo";
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      // Make the second Receive() call fail (i.e. the one expecting the
      // EligibilityEvalCheckinResponse).
      .WillOnce(Return(absl::AbortedError(expected_message)));

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  EXPECT_THAT(eligibility_checkin_result.status(), IsCode(ABORTED));
  EXPECT_THAT(eligibility_checkin_result.status().message(), expected_message);
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
703 
// Tests an eligibility eval checkin that is acked and then answered with a
// rejection: the call is expected to succeed and yield the Rejection variant,
// without the 'eet received' callback being invoked.
TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinRejection) {
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));

  // First the ack, then the rejected checkin response.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(
          DoAll(SetArgPointee<0>(GetFakeRejectedEligibilityCheckinResponse()),
                Return(absl::OkStatus())));

  // The 'eet received' callback should not be invoked since no EET was given to
  // the client.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(eligibility_checkin_result);
  EXPECT_THAT(*eligibility_checkin_result,
              VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
727 
// Tests an eligibility eval checkin that is acked and then answered with a
// 'disabled' response: the call is expected to succeed and yield the
// EligibilityEvalDisabled variant, without the 'eet received' callback being
// invoked.
TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinDisabled) {
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));

  // First the ack, then the 'disabled' checkin response.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
                      Return(absl::OkStatus())))
      .WillOnce(
          DoAll(SetArgPointee<0>(GetFakeDisabledEligibilityCheckinResponse()),
                Return(absl::OkStatus())));

  // The 'eet received' callback should not be invoked since no EET was given to
  // the client.
  EXPECT_CALL(mock_eet_received_callback_, Call(_)).Times(0);

  auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
      mock_eet_received_callback_.AsStdFunction());

  ASSERT_OK(eligibility_checkin_result);
  EXPECT_THAT(*eligibility_checkin_result,
              VariantWith<FederatedProtocol::EligibilityEvalDisabled>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
751 
TEST_P(GrpcFederatedProtocolTest,TestEligibilityEvalCheckinEnabled)752 TEST_P(GrpcFederatedProtocolTest, TestEligibilityEvalCheckinEnabled) {
753   // Note that in this particular test we check that the eligibility eval
754   // checkin request is as expected (in all prior tests we just use the '_'
755   // matcher, because the request isn't really relevant to the test).
756   EXPECT_CALL(*mock_grpc_bidi_stream_,
757               Send(Pointee(EqualsProto(GetExpectedEligibilityEvalCheckinRequest(
758                   enable_http_resource_support_)))))
759       .WillOnce(Return(absl::OkStatus()));
760 
761   // The EligibilityEvalCheckin(...) method should return the rejected
762   // RetryWindow, since after merely completing an eligibility eval checkin the
763   // client hasn't actually been accepted to a specific task yet.
764   std::string expected_plan = kPlan;
765   std::string expected_checkpoint = kInitCheckpoint;
766   std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
767   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
768       .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
769                       Return(absl::OkStatus())))
770       .WillOnce(
771           DoAll(SetArgPointee<0>(GetFakeEnabledEligibilityCheckinResponse(
772                     expected_plan, expected_checkpoint, expected_execution_id)),
773                 Return(absl::OkStatus())));
774   EXPECT_CALL(
775       mock_log_manager_,
776       LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
777 
778   // The 'EET received' callback should be called, even if the task resource
779   // data was available inline.
780   EXPECT_CALL(mock_eet_received_callback_,
781               Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
782                              Eq(std::nullopt))));
783 
784   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
785       mock_eet_received_callback_.AsStdFunction());
786 
787   ASSERT_OK(eligibility_checkin_result);
788   // If HTTP support is enabled then the checkpoint data gets returned in the
789   // shape of an absl::Cord (rather than an std::string), regardless of
790   // whether it was actually downloaded via HTTP.
791   if (enable_http_resource_support_) {
792     EXPECT_THAT(*eligibility_checkin_result,
793                 VariantWith<FederatedProtocol::EligibilityEvalTask>(
794                     FieldsAre(FieldsAre(absl::Cord(expected_plan),
795                                         absl::Cord(expected_checkpoint)),
796                               expected_execution_id, Eq(std::nullopt))));
797   } else {
798     EXPECT_THAT(*eligibility_checkin_result,
799                 VariantWith<FederatedProtocol::EligibilityEvalTask>(
800                     FieldsAre(FieldsAre(expected_plan, expected_checkpoint),
801                               expected_execution_id, Eq(std::nullopt))));
802   }
803   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
804 }
805 
TEST_P(GrpcFederatedProtocolTest,TestEligiblityEvalCheckinEnabledWithHttpResourcesDownloaded)806 TEST_P(GrpcFederatedProtocolTest,
807        TestEligiblityEvalCheckinEnabledWithHttpResourcesDownloaded) {
808   if (!enable_http_resource_support_) {
809     GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
810                     "is enabled";
811     return;
812   }
813 
814   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
815       .WillOnce(Return(absl::OkStatus()));
816 
817   std::string expected_plan = kPlan;
818   std::string plan_uri = "https://fake.uri/plan";
819   std::string expected_checkpoint = kInitCheckpoint;
820   std::string checkpoint_uri = "https://fake.uri/checkpoint";
821   std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
822   ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
823       /*plan=*/"", /*init_checkpoint=*/"", expected_execution_id);
824   EligibilityEvalPayload* eligibility_eval_payload =
825       fake_response.mutable_eligibility_eval_checkin_response()
826           ->mutable_eligibility_eval_payload();
827   eligibility_eval_payload->mutable_plan_resource()->set_uri(plan_uri);
828   eligibility_eval_payload->mutable_init_checkpoint_resource()->set_uri(
829       checkpoint_uri);
830 
831   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
832       .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
833                       Return(absl::OkStatus())))
834       .WillOnce(
835           DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
836   EXPECT_CALL(
837       mock_log_manager_,
838       LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
839 
840   {
841     InSequence seq;
842     // The 'EET received' callback should be called *before* the actual task
843     // resources are fetched.
844     EXPECT_CALL(mock_eet_received_callback_,
845                 Call(FieldsAre(FieldsAre("", ""), expected_execution_id,
846                                Eq(std::nullopt))));
847 
848     EXPECT_CALL(mock_http_client_,
849                 PerformSingleRequest(SimpleHttpRequestMatcher(
850                     plan_uri, HttpRequest::Method::kGet, _, "")))
851         .WillOnce(Return(FakeHttpResponse(200, {}, expected_plan)));
852 
853     EXPECT_CALL(mock_http_client_,
854                 PerformSingleRequest(SimpleHttpRequestMatcher(
855                     checkpoint_uri, HttpRequest::Method::kGet, _, "")))
856         .WillOnce(Return(FakeHttpResponse(200, {}, expected_checkpoint)));
857   }
858 
859   {
860     InSequence seq;
861     EXPECT_CALL(
862         mock_log_manager_,
863         LogDiag(
864             ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
865     EXPECT_CALL(
866         mock_log_manager_,
867         LogDiag(
868             ProdDiagCode::
869                 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED));
870   }
871 
872   // Issue the Eligibility Eval checkin.
873   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
874       mock_eet_received_callback_.AsStdFunction());
875 
876   ASSERT_OK(eligibility_checkin_result);
877   EXPECT_THAT(
878       *eligibility_checkin_result,
879       VariantWith<FederatedProtocol::EligibilityEvalTask>(FieldsAre(
880           FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
881           expected_execution_id, Eq(std::nullopt))));
882 }
883 
TEST_P(GrpcFederatedProtocolTest,TestEligiblityEvalCheckinEnabledWithHttpResourcesPlanDataFetchFailed)884 TEST_P(GrpcFederatedProtocolTest,
885        TestEligiblityEvalCheckinEnabledWithHttpResourcesPlanDataFetchFailed) {
886   if (!enable_http_resource_support_) {
887     GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
888                     "is enabled";
889     return;
890   }
891 
892   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
893       .WillOnce(Return(absl::OkStatus()));
894 
895   std::string expected_plan = kPlan;
896   std::string plan_uri = "https://fake.uri/plan";
897   std::string expected_checkpoint = kInitCheckpoint;
898   std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
899   ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
900       /*plan=*/"", expected_checkpoint, expected_execution_id);
901   fake_response.mutable_eligibility_eval_checkin_response()
902       ->mutable_eligibility_eval_payload()
903       ->mutable_plan_resource()
904       ->set_uri(plan_uri);
905   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
906       .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
907                       Return(absl::OkStatus())))
908       .WillOnce(
909           DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
910   EXPECT_CALL(
911       mock_log_manager_,
912       LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
913 
914   EXPECT_CALL(mock_http_client_,
915               PerformSingleRequest(SimpleHttpRequestMatcher(
916                   plan_uri, HttpRequest::Method::kGet, _, "")))
917       .WillOnce(Return(FakeHttpResponse(404, {}, "")));
918 
919   {
920     InSequence seq;
921     EXPECT_CALL(
922         mock_log_manager_,
923         LogDiag(
924             ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
925     EXPECT_CALL(
926         mock_log_manager_,
927         LogDiag(
928             ProdDiagCode::
929                 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
930   }
931 
932   // Issue the eligibility eval checkin.
933   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
934       mock_eet_received_callback_.AsStdFunction());
935 
936   EXPECT_THAT(eligibility_checkin_result.status(), IsCode(NOT_FOUND));
937   EXPECT_THAT(eligibility_checkin_result.status().message(),
938               HasSubstr("plan fetch failed"));
939   EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("404"));
940   // The EligibilityEvalCheckin call is expected to return the permanent error
941   // retry window, since 404 maps to a permanent error.
942   ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
943 }
944 
TEST_P(GrpcFederatedProtocolTest,TestEligiblityEvalCheckinEnabledWithHttpResourcesCheckpointFetchFailed)945 TEST_P(GrpcFederatedProtocolTest,
946        TestEligiblityEvalCheckinEnabledWithHttpResourcesCheckpointFetchFailed) {
947   if (!enable_http_resource_support_) {
948     GTEST_SKIP() << "This test only applies if the HTTP task resources feature "
949                     "is enabled";
950     return;
951   }
952 
953   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
954       .WillOnce(Return(absl::OkStatus()));
955 
956   std::string expected_plan = kPlan;
957   std::string expected_checkpoint = kInitCheckpoint;
958   std::string checkpoint_uri = "https://fake.uri/checkpoint";
959   std::string expected_execution_id = "ELIGIBILITY_EVAL_EXECUTION_ID";
960   ServerStreamMessage fake_response = GetFakeEnabledEligibilityCheckinResponse(
961       expected_plan, /*init_checkpoint=*/"", expected_execution_id);
962   fake_response.mutable_eligibility_eval_checkin_response()
963       ->mutable_eligibility_eval_payload()
964       ->mutable_init_checkpoint_resource()
965       ->set_uri(checkpoint_uri);
966   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
967       .WillOnce(DoAll(SetArgPointee<0>(GetFakeCheckinRequestAck()),
968                       Return(absl::OkStatus())))
969       .WillOnce(
970           DoAll(SetArgPointee<0>(fake_response), Return(absl::OkStatus())));
971   EXPECT_CALL(
972       mock_log_manager_,
973       LogDiag(ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED));
974 
975   EXPECT_CALL(mock_http_client_,
976               PerformSingleRequest(SimpleHttpRequestMatcher(
977                   checkpoint_uri, HttpRequest::Method::kGet, _, "")))
978       .WillOnce(Return(FakeHttpResponse(503, {}, "")));
979 
980   {
981     InSequence seq;
982     EXPECT_CALL(
983         mock_log_manager_,
984         LogDiag(
985             ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
986     EXPECT_CALL(
987         mock_log_manager_,
988         LogDiag(
989             ProdDiagCode::
990                 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
991   }
992 
993   // Issue the eligibility eval checkin.
994   auto eligibility_checkin_result = federated_protocol_->EligibilityEvalCheckin(
995       mock_eet_received_callback_.AsStdFunction());
996 
997   EXPECT_THAT(eligibility_checkin_result.status(), IsCode(UNAVAILABLE));
998   EXPECT_THAT(eligibility_checkin_result.status().message(),
999               HasSubstr("checkpoint fetch failed"));
1000   EXPECT_THAT(eligibility_checkin_result.status().message(), HasSubstr("503"));
1001   // The EligibilityEvalCheckin call is expected to return the rejected error
1002   // retry window, since 503 maps to a transient error.
1003   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1004 }
1005 
1006 // Tests that the protocol correctly sanitizes any invalid values it may have
1007 // received from the server.
TEST_P(GrpcFederatedProtocolTest,TestNegativeMinMaxRetryDelayValueSanitization)1008 TEST_P(GrpcFederatedProtocolTest,
1009        TestNegativeMinMaxRetryDelayValueSanitization) {
1010   google::internal::federatedml::v2::RetryWindow retry_window;
1011   retry_window.mutable_delay_min()->set_seconds(-1);
1012   retry_window.mutable_delay_max()->set_seconds(-2);
1013 
1014   // The above retry window's negative min/max values should be clamped to 0.
1015   google::internal::federatedml::v2::RetryWindow expected_retry_window;
1016   expected_retry_window.mutable_delay_min()->set_seconds(0);
1017   expected_retry_window.mutable_delay_max()->set_seconds(0);
1018 
1019   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
1020       /* eligibility_eval_enabled=*/true, retry_window, retry_window));
1021   const RetryWindow& actual_retry_window =
1022       federated_protocol_->GetLatestRetryWindow();
1023   // The above retry window's invalid max value should be clamped to the min
1024   // value (minus some errors introduced by the inaccuracy of double
1025   // multiplication).
1026   EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1027                   actual_retry_window.delay_min().nanos() / 1000000000.0,
1028               DoubleEq(0));
1029   EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1030                   actual_retry_window.delay_max().nanos() / 1000000000.0,
1031               DoubleEq(0));
1032 }
1033 
1034 // Tests that the protocol correctly sanitizes any invalid values it may have
1035 // received from the server.
TEST_P(GrpcFederatedProtocolTest,TestInvalidMaxRetryDelayValueSanitization)1036 TEST_P(GrpcFederatedProtocolTest, TestInvalidMaxRetryDelayValueSanitization) {
1037   google::internal::federatedml::v2::RetryWindow retry_window;
1038   retry_window.mutable_delay_min()->set_seconds(1234);
1039   retry_window.mutable_delay_max()->set_seconds(1233);  // less than delay_min
1040 
1041   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin(
1042       /* eligibility_eval_enabled=*/true, retry_window, retry_window));
1043   const RetryWindow& actual_retry_window =
1044       federated_protocol_->GetLatestRetryWindow();
1045   // The above retry window's invalid max value should be clamped to the min
1046   // value (minus some errors introduced by the inaccuracy of double
1047   // multiplication). Note that DoubleEq enforces too precise of bounds, so we
1048   // use DoubleNear instead.
1049   EXPECT_THAT(actual_retry_window.delay_min().seconds() +
1050                   actual_retry_window.delay_min().nanos() / 1000000000.0,
1051               DoubleNear(1234.0, 0.02));
1052   EXPECT_THAT(actual_retry_window.delay_max().seconds() +
1053                   actual_retry_window.delay_max().nanos() / 1000000000.0,
1054               DoubleNear(1234.0, 0.02));
1055 }
1056 
// Death test: calling Checkin(...) without a TaskEligibilityInfo after an
// eligibility eval checkin that handed out a task is expected to terminate the
// process.
TEST_P(GrpcFederatedProtocolDeathTest, TestCheckinMissingTaskEligibilityInfo) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // A Checkin(...) request with a missing TaskEligibilityInfo should now fail,
  // as the protocol requires us to provide one based on the plan included in
  // the eligibility eval checkin response payload.
  ASSERT_DEATH(
      {
        // The result is bound to a local only so that it is not discarded.
        auto unused = federated_protocol_->Checkin(
            std::nullopt, mock_task_received_callback_.AsStdFunction());
      },
      _);
}
1071 
// Tests that an UNAVAILABLE error from Send() during Checkin(...) is passed
// through, code and message, as the result, and that the rejected retry window
// is used.
TEST_P(GrpcFederatedProtocolTest, TestCheckinSendFailsTransientError) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // Make the gRPC stream return an UNAVAILABLE error when the Checkin(...) code
  // tries to send its first message. This should result in the error being
  // returned as the result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::UnavailableError("foo")));

  auto checkin_result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());
  EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
  EXPECT_THAT(checkin_result.status().message(), "foo");
  // RetryWindows were already received from the server during the eligibility
  // eval checkin, so we expect to get a 'rejected' retry window.
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1091 
// Tests that a NOT_FOUND error from Send() during Checkin(...) is passed
// through, code and message, as the result, and that the permanent error retry
// window is used.
TEST_P(GrpcFederatedProtocolTest, TestCheckinSendFailsPermanentError) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // Make the gRPC stream return a NOT_FOUND error when the Checkin(...) code
  // tries to send its first message. This should result in the error being
  // returned as the result.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::NotFoundError("foo")));

  auto checkin_result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());
  EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
  EXPECT_THAT(checkin_result.status().message(), "foo");
  // Even though RetryWindows were already received from the server during the
  // eligibility eval checkin, we expect a RetryWindow generated based on the
  // *permanent* errors retry delay flag, since NOT_FOUND is marked as a
  // permanent error in the flags, and permanent errors should always result in
  // permanent error windows (regardless of whether a CheckinRequestAck was
  // already received).
  ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1115 
1116 // Tests the case where the blocking Send() call in Checkin is interrupted.
TEST_P(GrpcFederatedProtocolTest,TestCheckinSendInterrupted)1117 TEST_P(GrpcFederatedProtocolTest, TestCheckinSendInterrupted) {
1118   // Issue an eligibility eval checkin first.
1119   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1120 
1121   absl::BlockingCounter counter_should_abort(1);
1122 
1123   // Make Send() block until the counter is decremented.
1124   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
1125       .WillOnce([&counter_should_abort](ClientStreamMessage* ignored) {
1126         counter_should_abort.Wait();
1127         return absl::OkStatus();
1128       });
1129   // Make should_abort return false for the first two calls, and then make it
1130   // decrement the counter and return true, triggering an abort sequence and
1131   // unblocking the Send() call we caused to block above.
1132   EXPECT_CALL(mock_should_abort_, Call())
1133       .WillOnce(Return(false))
1134       .WillOnce(Return(false))
1135       .WillRepeatedly([&counter_should_abort] {
1136         counter_should_abort.DecrementCount();
1137         return true;
1138       });
1139   // In addition to the Close() call we expect in the test fixture above, expect
1140   // an additional one (the one that induced the abort).
1141   EXPECT_CALL(*mock_grpc_bidi_stream_, Close()).Times(1).RetiresOnSaturation();
1142   EXPECT_CALL(mock_log_manager_,
1143               LogDiag(ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC));
1144 
1145   auto checkin_result = federated_protocol_->Checkin(
1146       GetFakeTaskEligibilityInfo(),
1147       mock_task_received_callback_.AsStdFunction());
1148   EXPECT_THAT(checkin_result.status(), IsCode(CANCELLED));
1149   // RetryWindows were already received from the server during the eligibility
1150   // eval checkin, so we expect to get a 'rejected' retry window.
1151   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1152 }
1153 
// Tests a regular checkin, issued with a TaskEligibilityInfo, that the server
// answers with a rejection: the call is expected to succeed and yield the
// Rejection variant, without the 'task received' callback being invoked.
TEST_P(GrpcFederatedProtocolTest, TestCheckinRejectionWithTaskEligibilityInfo) {
  // Issue an eligibility eval checkin first.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());

  // Expect a checkin request for the next call to Checkin(...).
  EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
      .WillOnce(Return(absl::OkStatus()));
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeRejectedCheckinResponse()),
                      Return(absl::OkStatus())));

  // The 'task received' callback should not be invoked since no task was given
  // to the client.
  EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);

  // Issue the regular checkin.
  auto checkin_result = federated_protocol_->Checkin(
      GetFakeTaskEligibilityInfo(),
      mock_task_received_callback_.AsStdFunction());

  ASSERT_OK(checkin_result.status());
  EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
  ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
}
1178 
1179 // Tests whether we can issue a Checkin() request correctly without passing a
1180 // TaskEligibilityInfo, in the case that the eligibility eval checkin didn't
1181 // return any eligibility eval task to run.
TEST_P(GrpcFederatedProtocolTest,TestCheckinRejectionWithoutTaskEligibilityInfo)1182 TEST_P(GrpcFederatedProtocolTest,
1183        TestCheckinRejectionWithoutTaskEligibilityInfo) {
1184   // Issue an eligibility eval checkin first.
1185   ASSERT_OK(
1186       RunSuccessfulEligibilityEvalCheckin(/*eligibility_eval_enabled=*/false));
1187 
1188   // Expect a checkin request for the next call to Checkin(...).
1189   EXPECT_CALL(*mock_grpc_bidi_stream_,
1190               Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
1191                   /*task_eligibility_info=*/std::nullopt,
1192                   enable_http_resource_support_)))))
1193       .WillOnce(Return(absl::OkStatus()));
1194   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1195       .WillOnce(DoAll(SetArgPointee<0>(GetFakeRejectedCheckinResponse()),
1196                       Return(absl::OkStatus())));
1197 
1198   // The 'task received' callback should not be invoked since no task was given
1199   // to the client.
1200   EXPECT_CALL(mock_task_received_callback_, Call(_)).Times(0);
1201 
1202   // Issue the regular checkin, without a TaskEligibilityInfo (since we didn't
1203   // receive an eligibility eval task to run during eligibility eval checkin).
1204   auto checkin_result = federated_protocol_->Checkin(
1205       std::nullopt, mock_task_received_callback_.AsStdFunction());
1206 
1207   ASSERT_OK(checkin_result.status());
1208   EXPECT_THAT(*checkin_result, VariantWith<FederatedProtocol::Rejection>(_));
1209   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1210 }
1211 
TEST_P(GrpcFederatedProtocolTest,TestCheckinAccept)1212 TEST_P(GrpcFederatedProtocolTest, TestCheckinAccept) {
1213   // Issue an eligibility eval checkin first.
1214   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1215 
1216   // Once the eligibility eval checkin has succeeded, let's fake some network
1217   // stats data so that we can verify that it is logged correctly.
1218   int64_t chunking_layer_bytes_downloaded = 555;
1219   int64_t chunking_layer_bytes_uploaded = 666;
1220   EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
1221       .WillRepeatedly(Return(chunking_layer_bytes_downloaded));
1222   EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
1223       .WillRepeatedly(Return(chunking_layer_bytes_uploaded));
1224 
1225   // Note that in this particular test we check that the CheckinRequest is as
1226   // expected (in all prior tests we just use the '_' matcher, because the
1227   // request isn't really relevant to the test).
1228   TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
1229   EXPECT_CALL(*mock_grpc_bidi_stream_,
1230               Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
1231                   expected_eligibility_info, enable_http_resource_support_)))))
1232       .WillOnce(Return(absl::OkStatus()));
1233 
1234   std::string expected_plan = kPlan;
1235   std::string expected_checkpoint = kInitCheckpoint;
1236   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1237       .WillOnce(DoAll(SetArgPointee<0>(GetFakeAcceptedCheckinResponse(
1238                           expected_plan, expected_checkpoint,
1239                           kFederatedSelectUriTemplate, kExecutionPhaseId,
1240                           /* use_secure_aggregation=*/true)),
1241                       Return(absl::OkStatus())));
1242 
1243   // The 'task received' callback should be called even when the resources were
1244   // available inline.
1245   EXPECT_CALL(
1246       mock_task_received_callback_,
1247       Call(FieldsAre(
1248           FieldsAre("", ""), kFederatedSelectUriTemplate, kExecutionPhaseId,
1249           Optional(AllOf(
1250               Field(&FederatedProtocol::SecAggInfo::expected_number_of_clients,
1251                     kSecAggExpectedNumberOfClients),
1252               Field(&FederatedProtocol::SecAggInfo::
1253                         minimum_clients_in_server_visible_aggregate,
1254                     kSecAggMinClientsInServerVisibleAggregate))))));
1255 
1256   // Issue the regular checkin.
1257   auto checkin_result = federated_protocol_->Checkin(
1258       expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
1259 
1260   ASSERT_OK(checkin_result.status());
1261   // If HTTP support is enabled then the checkpoint data gets returned in the
1262   // shape of an absl::Cord (rather than an std::string), regardless of whether
1263   // it was actually downloaded via HTTP.
1264   if (enable_http_resource_support_) {
1265     EXPECT_THAT(
1266         *checkin_result,
1267         VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
1268             FieldsAre(absl::Cord(expected_plan),
1269                       absl::Cord(expected_checkpoint)),
1270             kFederatedSelectUriTemplate, kExecutionPhaseId,
1271             Optional(AllOf(
1272                 Field(
1273                     &FederatedProtocol::SecAggInfo::expected_number_of_clients,
1274                     kSecAggExpectedNumberOfClients),
1275                 Field(&FederatedProtocol::SecAggInfo::
1276                           minimum_clients_in_server_visible_aggregate,
1277                       kSecAggMinClientsInServerVisibleAggregate))))));
1278   } else {
1279     EXPECT_THAT(
1280         *checkin_result,
1281         VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
1282             FieldsAre(expected_plan, expected_checkpoint),
1283             kFederatedSelectUriTemplate, kExecutionPhaseId,
1284             Optional(AllOf(
1285                 Field(
1286                     &FederatedProtocol::SecAggInfo::expected_number_of_clients,
1287                     kSecAggExpectedNumberOfClients),
1288                 Field(&FederatedProtocol::SecAggInfo::
1289                           minimum_clients_in_server_visible_aggregate,
1290                       kSecAggMinClientsInServerVisibleAggregate))))));
1291   }
1292   // The Checkin call is expected to return the accepted retry window from the
1293   // CheckinRequestAck response to the first eligibility eval request.
1294   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1295 }
1296 
TEST_P(GrpcFederatedProtocolTest,TestCheckinAcceptWithHttpResourcesDownloaded)1297 TEST_P(GrpcFederatedProtocolTest,
1298        TestCheckinAcceptWithHttpResourcesDownloaded) {
1299   if (!enable_http_resource_support_) {
1300     GTEST_SKIP() << "This test only applies the HTTP task resources feature "
1301                     "is enabled";
1302     return;
1303   }
1304   // Issue an eligibility eval checkin first.
1305   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1306 
1307   // Once the eligibility eval checkin has succeeded, let's fake some network
1308   // stats data so that we can verify that it is logged correctly.
1309   int64_t chunking_layer_bytes_downloaded = 555;
1310   int64_t chunking_layer_bytes_uploaded = 666;
1311   EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesReceived())
1312       .WillRepeatedly(Return(chunking_layer_bytes_downloaded));
1313   EXPECT_CALL(*mock_grpc_bidi_stream_, ChunkingLayerBytesSent())
1314       .WillRepeatedly(Return(chunking_layer_bytes_uploaded));
1315 
1316   // Note that in this particular test we check that the CheckinRequest is as
1317   // expected (in all prior tests we just use the '_' matcher, because the
1318   // request isn't really relevant to the test).
1319   TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
1320   EXPECT_CALL(
1321       *mock_grpc_bidi_stream_,
1322       Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
1323           expected_eligibility_info, /*enable_http_resource_support=*/true)))))
1324       .WillOnce(Return(absl::OkStatus()));
1325 
1326   std::string expected_plan = kPlan;
1327   std::string plan_uri = "https://fake.uri/plan";
1328   std::string expected_checkpoint = kInitCheckpoint;
1329   std::string checkpoint_uri = "https://fake.uri/checkpoint";
1330   ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
1331       /*plan=*/"", /*init_checkpoint=*/"", kFederatedSelectUriTemplate,
1332       kExecutionPhaseId,
1333       /* use_secure_aggregation=*/true);
1334   AcceptanceInfo* acceptance_info =
1335       fake_checkin_response.mutable_checkin_response()
1336           ->mutable_acceptance_info();
1337   acceptance_info->mutable_plan_resource()->set_uri(plan_uri);
1338   acceptance_info->mutable_init_checkpoint_resource()->set_uri(checkpoint_uri);
1339   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1340       .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
1341                       Return(absl::OkStatus())));
1342 
1343   {
1344     InSequence seq;
1345     // The 'task received' callback should be called *before* the actual task
1346     // resources are fetched.
1347     EXPECT_CALL(
1348         mock_task_received_callback_,
1349         Call(FieldsAre(
1350             FieldsAre("", ""), kFederatedSelectUriTemplate, kExecutionPhaseId,
1351             Optional(AllOf(
1352                 Field(
1353                     &FederatedProtocol::SecAggInfo::expected_number_of_clients,
1354                     kSecAggExpectedNumberOfClients),
1355                 Field(&FederatedProtocol::SecAggInfo::
1356                           minimum_clients_in_server_visible_aggregate,
1357                       kSecAggMinClientsInServerVisibleAggregate))))));
1358 
1359     EXPECT_CALL(mock_http_client_,
1360                 PerformSingleRequest(SimpleHttpRequestMatcher(
1361                     plan_uri, HttpRequest::Method::kGet, _, "")))
1362         .WillOnce(Return(FakeHttpResponse(200, {}, expected_plan)));
1363 
1364     EXPECT_CALL(mock_http_client_,
1365                 PerformSingleRequest(SimpleHttpRequestMatcher(
1366                     checkpoint_uri, HttpRequest::Method::kGet, _, "")))
1367         .WillOnce(Return(FakeHttpResponse(200, {}, expected_checkpoint)));
1368   }
1369 
1370   {
1371     InSequence seq;
1372     EXPECT_CALL(
1373         mock_log_manager_,
1374         LogDiag(
1375             ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
1376     EXPECT_CALL(
1377         mock_log_manager_,
1378         LogDiag(
1379             ProdDiagCode::
1380                 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED));
1381   }
1382 
1383   // Issue the regular checkin.
1384   auto checkin_result = federated_protocol_->Checkin(
1385       expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
1386 
1387   ASSERT_OK(checkin_result.status());
1388   EXPECT_THAT(
1389       *checkin_result,
1390       VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
1391           FieldsAre(absl::Cord(expected_plan), absl::Cord(expected_checkpoint)),
1392           kFederatedSelectUriTemplate, kExecutionPhaseId,
1393           Optional(AllOf(
1394               Field(&FederatedProtocol::SecAggInfo::expected_number_of_clients,
1395                     kSecAggExpectedNumberOfClients),
1396               Field(&FederatedProtocol::SecAggInfo::
1397                         minimum_clients_in_server_visible_aggregate,
1398                     kSecAggMinClientsInServerVisibleAggregate))))));
1399   // The Checkin call is expected to return the accepted retry window from the
1400   // CheckinRequestAck response to the first eligibility eval request.
1401   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1402 }
1403 
TEST_P(GrpcFederatedProtocolTest,TestCheckinAcceptWithHttpResourcePlanDataFetchFailed)1404 TEST_P(GrpcFederatedProtocolTest,
1405        TestCheckinAcceptWithHttpResourcePlanDataFetchFailed) {
1406   if (!enable_http_resource_support_) {
1407     GTEST_SKIP() << "This test only applies the HTTP task resources feature "
1408                     "is enabled";
1409     return;
1410   }
1411   // Issue an eligibility eval checkin first.
1412   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1413 
1414   // Note that in this particular test we check that the CheckinRequest is as
1415   // expected (in all prior tests we just use the '_' matcher, because the
1416   // request isn't really relevant to the test).
1417   TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
1418   EXPECT_CALL(
1419       *mock_grpc_bidi_stream_,
1420       Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
1421           expected_eligibility_info, /*enable_http_resource_support=*/true)))))
1422       .WillOnce(Return(absl::OkStatus()));
1423 
1424   std::string expected_plan = kPlan;
1425   std::string plan_uri = "https://fake.uri/plan";
1426   std::string expected_checkpoint = kInitCheckpoint;
1427   ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
1428       /*plan=*/"", expected_checkpoint, kFederatedSelectUriTemplate,
1429       kExecutionPhaseId,
1430       /* use_secure_aggregation=*/true);
1431   AcceptanceInfo* acceptance_info =
1432       fake_checkin_response.mutable_checkin_response()
1433           ->mutable_acceptance_info();
1434   acceptance_info->mutable_plan_resource()->set_uri(plan_uri);
1435   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1436       .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
1437                       Return(absl::OkStatus())));
1438 
1439   // Mock a failed plan fetch.
1440   EXPECT_CALL(mock_http_client_,
1441               PerformSingleRequest(SimpleHttpRequestMatcher(
1442                   plan_uri, HttpRequest::Method::kGet, _, "")))
1443       .WillOnce(Return(FakeHttpResponse(404, {}, "")));
1444 
1445   {
1446     InSequence seq;
1447     EXPECT_CALL(
1448         mock_log_manager_,
1449         LogDiag(
1450             ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
1451     EXPECT_CALL(
1452         mock_log_manager_,
1453         LogDiag(
1454             ProdDiagCode::
1455                 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
1456   }
1457 
1458   // Issue the regular checkin.
1459   auto checkin_result = federated_protocol_->Checkin(
1460       expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
1461 
1462   EXPECT_THAT(checkin_result.status(), IsCode(NOT_FOUND));
1463   EXPECT_THAT(checkin_result.status().message(),
1464               HasSubstr("plan fetch failed"));
1465   EXPECT_THAT(checkin_result.status().message(), HasSubstr("404"));
1466   // The Checkin call is expected to return the permanent error retry window,
1467   // since 404 maps to a permanent error.
1468   ExpectPermanentErrorRetryWindow(federated_protocol_->GetLatestRetryWindow());
1469 }
1470 
TEST_P(GrpcFederatedProtocolTest,TestCheckinAcceptWithHttpResourceCheckpointDataFetchFailed)1471 TEST_P(GrpcFederatedProtocolTest,
1472        TestCheckinAcceptWithHttpResourceCheckpointDataFetchFailed) {
1473   if (!enable_http_resource_support_) {
1474     GTEST_SKIP() << "This test only applies the HTTP task resources feature "
1475                     "is enabled";
1476     return;
1477   }
1478   // Issue an eligibility eval checkin first.
1479   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1480 
1481   // Note that in this particular test we check that the CheckinRequest is as
1482   // expected (in all prior tests we just use the '_' matcher, because the
1483   // request isn't really relevant to the test).
1484   TaskEligibilityInfo expected_eligibility_info = GetFakeTaskEligibilityInfo();
1485   EXPECT_CALL(
1486       *mock_grpc_bidi_stream_,
1487       Send(Pointee(EqualsProto(GetExpectedCheckinRequest(
1488           expected_eligibility_info, /*enable_http_resource_support=*/true)))))
1489       .WillOnce(Return(absl::OkStatus()));
1490 
1491   std::string expected_plan = kPlan;
1492   std::string expected_checkpoint = kInitCheckpoint;
1493   std::string checkpoint_uri = "https://fake.uri/checkpoint";
1494   ServerStreamMessage fake_checkin_response = GetFakeAcceptedCheckinResponse(
1495       expected_plan, /*init_checkpoint=*/"", kFederatedSelectUriTemplate,
1496       kExecutionPhaseId,
1497       /* use_secure_aggregation=*/true);
1498   AcceptanceInfo* acceptance_info =
1499       fake_checkin_response.mutable_checkin_response()
1500           ->mutable_acceptance_info();
1501   acceptance_info->mutable_init_checkpoint_resource()->set_uri(checkpoint_uri);
1502   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1503       .WillOnce(DoAll(SetArgPointee<0>(fake_checkin_response),
1504                       Return(absl::OkStatus())));
1505 
1506   // Mock a failed checkpoint fetch.
1507   EXPECT_CALL(mock_http_client_,
1508               PerformSingleRequest(SimpleHttpRequestMatcher(
1509                   checkpoint_uri, HttpRequest::Method::kGet, _, "")))
1510       .WillOnce(Return(FakeHttpResponse(503, {}, "")));
1511 
1512   {
1513     InSequence seq;
1514     EXPECT_CALL(
1515         mock_log_manager_,
1516         LogDiag(
1517             ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP));
1518     EXPECT_CALL(
1519         mock_log_manager_,
1520         LogDiag(
1521             ProdDiagCode::
1522                 HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED));
1523   }
1524 
1525   // Issue the regular checkin.
1526   auto checkin_result = federated_protocol_->Checkin(
1527       expected_eligibility_info, mock_task_received_callback_.AsStdFunction());
1528 
1529   EXPECT_THAT(checkin_result.status(), IsCode(UNAVAILABLE));
1530   EXPECT_THAT(checkin_result.status().message(),
1531               HasSubstr("checkpoint fetch failed"));
1532   EXPECT_THAT(checkin_result.status().message(), HasSubstr("503"));
1533   // The Checkin call is expected to return the rejected retry window from the
1534   // response to the first eligibility eval request.
1535   ExpectRejectedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1536 }
1537 
// Verifies that an accepted checkin for a task that does not use secure
// aggregation yields a TaskAssignment without any SecAggInfo.
TEST_P(GrpcFederatedProtocolTest, TestCheckinAcceptNonSecAgg) {
  // Eligibility eval checkin first, then a successful regular checkin that
  // returns a non-SecAgg task.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  auto assignment_or = RunSuccessfulCheckin(/*use_secure_aggregation=*/false);
  ASSERT_OK(assignment_or.status());

  // The result must not carry any SecAggInfo.
  const auto no_secagg_info = Eq(std::nullopt);

  // Without HTTP support the payloads are plain strings; with HTTP support
  // they are returned as absl::Cord, regardless of whether they were actually
  // downloaded via HTTP.
  if (!enable_http_resource_support_) {
    EXPECT_THAT(*assignment_or,
                VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
                    FieldsAre(kPlan, kInitCheckpoint),
                    kFederatedSelectUriTemplate, kExecutionPhaseId,
                    no_secagg_info)));
  } else {
    EXPECT_THAT(*assignment_or,
                VariantWith<FederatedProtocol::TaskAssignment>(FieldsAre(
                    FieldsAre(absl::Cord(kPlan), absl::Cord(kInitCheckpoint)),
                    kFederatedSelectUriTemplate, kExecutionPhaseId,
                    no_secagg_info)));
  }
}
1563 
// Verifies that ReportCompleted() on a SecAgg task creates a SecAgg runner via
// the factory and runs it with both the TF checkpoint and the quantized
// tensor contained in the computation results.
TEST_P(GrpcFederatedProtocolTest, TestReportWithSecAgg) {
  // Eligibility eval checkin first, then a successful SecAgg checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/true));

  // SecAgg-style results: a TF checkpoint combined with one (empty) quantized
  // aggregand.
  ComputationResults secagg_results;
  secagg_results.emplace("tensorflow_checkpoint", "");
  secagg_results.emplace("some_tensor", QuantizedTensor());

  // The factory hands out the mock runner, wrapping the raw pointer in a
  // unique_ptr.
  mock_secagg_runner_ = new StrictMock<MockSecAggRunner>();
  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, kSecAggExpectedNumberOfClients,
                                 kSecAggMinSurvivingClientsForReconstruction))
      .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner_))));

  // The runner must receive exactly the two entries above, in any order.
  const auto checkpoint_entry =
      Pair("tensorflow_checkpoint", VariantWith<TFCheckpoint>(IsEmpty()));
  const auto tensor_entry =
      Pair("some_tensor",
           VariantWith<QuantizedTensor>(FieldsAre(IsEmpty(), 0, IsEmpty())));
  EXPECT_CALL(*mock_secagg_runner_,
              Run(UnorderedElementsAre(checkpoint_entry, tensor_entry)))
      .WillOnce(Return(absl::OkStatus()));

  // The server answers with a fake report response.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeReportResponse()),
                      Return(absl::OkStatus())));

  const auto report_status = federated_protocol_->ReportCompleted(
      std::move(secagg_results), absl::ZeroDuration(), std::nullopt);
  EXPECT_OK(report_status);
}
1593 
// Verifies that ReportCompleted() on a SecAgg task also succeeds when the
// computation results contain only a quantized tensor and no TF checkpoint.
TEST_P(GrpcFederatedProtocolTest, TestReportWithSecAggWithoutTFCheckpoint) {
  // Eligibility eval checkin first, then a successful SecAgg checkin.
  ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
  ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/true));

  // Results with a single (empty) quantized aggregand and nothing else.
  ComputationResults secagg_results;
  secagg_results.emplace("some_tensor", QuantizedTensor());

  // The factory hands out the mock runner, wrapping the raw pointer in a
  // unique_ptr.
  mock_secagg_runner_ = new StrictMock<MockSecAggRunner>();
  EXPECT_CALL(*mock_secagg_runner_factory_,
              CreateSecAggRunner(_, _, _, _, _, kSecAggExpectedNumberOfClients,
                                 kSecAggMinSurvivingClientsForReconstruction))
      .WillOnce(Return(ByMove(absl::WrapUnique(mock_secagg_runner_))));

  // The runner must receive exactly the one tensor entry.
  const auto tensor_entry =
      Pair("some_tensor",
           VariantWith<QuantizedTensor>(FieldsAre(IsEmpty(), 0, IsEmpty())));
  EXPECT_CALL(*mock_secagg_runner_, Run(UnorderedElementsAre(tensor_entry)))
      .WillOnce(Return(absl::OkStatus()));

  // The server answers with a fake report response.
  EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
      .WillOnce(DoAll(SetArgPointee<0>(GetFakeReportResponse()),
                      Return(absl::OkStatus())));

  const auto report_status = federated_protocol_->ReportCompleted(
      std::move(secagg_results), absl::ZeroDuration(), std::nullopt);
  EXPECT_OK(report_status);
}
1619 
1620 // This function tests the Report(...) method's Send code path, ensuring the
1621 // right events are logged / and the right data is transmitted to the server.
TEST_P(GrpcFederatedProtocolTest,TestReportSendFails)1622 TEST_P(GrpcFederatedProtocolTest, TestReportSendFails) {
1623   // Issue an eligibility eval checkin first, followed by a successful checkin.
1624   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1625   ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1626 
1627   // 1. Create input for the Report function.
1628   std::string checkpoint_str;
1629   const size_t kTFCheckpointSize = 32;
1630   checkpoint_str.resize(kTFCheckpointSize, 'X');
1631   ComputationResults results;
1632   results.emplace("tensorflow_checkpoint", checkpoint_str);
1633 
1634   absl::Duration plan_duration = absl::Milliseconds(1337);
1635 
1636   // 2. The expected message sent to the server by the ReportCompleted()
1637   // function, as text proto.
1638   ClientStreamMessage expected_client_stream_message;
1639   ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
1640       absl::StrCat(
1641           "report_request {", "  population_name: \"", kPopulationName, "\"",
1642           "  execution_phase_id: \"", kExecutionPhaseId, "\"", "  report {",
1643           "    update_checkpoint: \"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\"",
1644           "    serialized_train_event {", "[type.googleapis.com/",
1645           "google.internal.federatedml.v2.ClientExecutionStats] {",
1646           "        duration { seconds: 1 nanos: 337000000 }", "      }",
1647           "    }", "  }", "}"),
1648       &expected_client_stream_message));
1649 
1650   // 3. Set up mocks.
1651   EXPECT_CALL(*mock_grpc_bidi_stream_,
1652               Send(Pointee(EqualsProto(expected_client_stream_message))))
1653       .WillOnce(Return(absl::AbortedError("foo")));
1654 
1655   // 4. Test that ReportCompleted() sends the expected message.
1656   auto report_result = federated_protocol_->ReportCompleted(
1657       std::move(results), plan_duration, std::nullopt);
1658   EXPECT_THAT(report_result, IsCode(ABORTED));
1659   EXPECT_THAT(report_result.message(), HasSubstr("foo"));
1660 
1661   // If we made it to the Report protocol phase, then the client must've been
1662   // accepted during the Checkin phase first, and so we should receive the
1663   // "accepted" RetryWindow.
1664   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1665 }
1666 
1667 // This function tests the happy path of ReportCompleted() - results get
1668 // reported, server replies with a RetryWindow.
TEST_P(GrpcFederatedProtocolTest,TestPublishReportSuccess)1669 TEST_P(GrpcFederatedProtocolTest, TestPublishReportSuccess) {
1670   // Issue an eligibility eval checkin first, followed by a successful checkin.
1671   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1672   ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1673 
1674   // 1. Create input for the Report function.
1675   ComputationResults results;
1676   results.emplace("tensorflow_checkpoint", "");
1677 
1678   // 2. Set up mocks.
1679   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
1680       .WillOnce(Return(absl::OkStatus()));
1681   ServerStreamMessage response_message;
1682   response_message.mutable_report_response();
1683   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1684       .WillOnce(
1685           DoAll(SetArgPointee<0>(response_message), Return(absl::OkStatus())));
1686 
1687   // 3. Test that ReportCompleted() sends the expected message.
1688   auto report_result = federated_protocol_->ReportCompleted(
1689       std::move(results), absl::ZeroDuration(), std::nullopt);
1690   EXPECT_OK(report_result);
1691 
1692   // If we made it to the Report protocol phase, then the client must've been
1693   // accepted during the Checkin phase first, and so we should receive the
1694   // "accepted" RetryWindow.
1695   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1696 }
1697 
1698 // This function tests the Send code path when PhaseOutcome indicates an
1699 // error. / In that case, no checkpoint, and only the duration stat, should be
1700 // uploaded.
TEST_P(GrpcFederatedProtocolTest,TestPublishReportNotCompleteSendFails)1701 TEST_P(GrpcFederatedProtocolTest, TestPublishReportNotCompleteSendFails) {
1702   // Issue an eligibility eval checkin first, followed by a successful checkin.
1703   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1704   ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1705 
1706   // 1. Create input for the Report function.
1707   absl::Duration plan_duration = absl::Milliseconds(1337);
1708 
1709   // 2. The expected message sent to the server by the ReportNotCompleted()
1710   // function, as text proto.
1711   ClientStreamMessage expected_client_stream_message;
1712   ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(
1713       absl::StrCat("report_request {", "  population_name: \"", kPopulationName,
1714                    "\"", "  execution_phase_id: \"", kExecutionPhaseId, "\"",
1715                    "  report {", "    serialized_train_event {",
1716                    "[type.googleapis.com/",
1717                    "google.internal.federatedml.v2.ClientExecutionStats] {",
1718                    "        duration { seconds: 1 nanos: 337000000 }",
1719                    "      }", "    }", "    status_code: INTERNAL", "  }", "}"),
1720       &expected_client_stream_message));
1721 
1722   // 3. Set up mocks.
1723   EXPECT_CALL(*mock_grpc_bidi_stream_,
1724               Send(Pointee(EqualsProto(expected_client_stream_message))))
1725       .WillOnce(Return(absl::AbortedError("foo")));
1726 
1727   // 4. Test that ReportNotCompleted() sends the expected message.
1728   auto report_result = federated_protocol_->ReportNotCompleted(
1729       engine::PhaseOutcome::ERROR, plan_duration, std::nullopt);
1730   EXPECT_THAT(report_result, IsCode(ABORTED));
1731   EXPECT_THAT(report_result.message(), HasSubstr("foo"));
1732 
1733   // If we made it to the Report protocol phase, then the client must've been
1734   // accepted during the Checkin phase first, and so we should receive the
1735   // "accepted" RetryWindow.
1736   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1737 }
1738 
1739 // This function tests the happy path of ReportCompleted() - results get
1740 // reported, server replies with a RetryWindow.
TEST_P(GrpcFederatedProtocolTest,TestPublishReportSuccessCommitsToOpstats)1741 TEST_P(GrpcFederatedProtocolTest, TestPublishReportSuccessCommitsToOpstats) {
1742   // Issue an eligibility eval checkin first, followed by a successful checkin.
1743   ASSERT_OK(RunSuccessfulEligibilityEvalCheckin());
1744   ASSERT_OK(RunSuccessfulCheckin(/*use_secure_aggregation=*/false));
1745 
1746   // 1. Create input for the Report function.
1747   ComputationResults results;
1748   results.emplace("tensorflow_checkpoint", "");
1749 
1750   // 2. Set up mocks.
1751   EXPECT_CALL(*mock_grpc_bidi_stream_, Send(_))
1752       .WillOnce(Return(absl::OkStatus()));
1753   ServerStreamMessage response_message;
1754   response_message.mutable_report_response();
1755   EXPECT_CALL(*mock_grpc_bidi_stream_, Receive(_))
1756       .WillOnce(
1757           DoAll(SetArgPointee<0>(response_message), Return(absl::OkStatus())));
1758 
1759   // 3. Test that ReportCompleted() sends the expected message.
1760   auto report_result = federated_protocol_->ReportCompleted(
1761       std::move(results), absl::ZeroDuration(), std::nullopt);
1762   EXPECT_OK(report_result);
1763 
1764   // If we made it to the Report protocol phase, then the client must've been
1765   // accepted during the Checkin phase first, and so we should receive the
1766   // "accepted" RetryWindow.
1767   ExpectAcceptedRetryWindow(federated_protocol_->GetLatestRetryWindow());
1768 }
1769 
1770 }  // anonymous namespace
1771 }  // namespace fcp::client
1772