1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/grpc_federated_protocol.h"
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Worker #include <algorithm>
19*14675a02SAndroid Build Coastguard Worker #include <functional>
20*14675a02SAndroid Build Coastguard Worker #include <memory>
21*14675a02SAndroid Build Coastguard Worker #include <optional>
22*14675a02SAndroid Build Coastguard Worker #include <string>
23*14675a02SAndroid Build Coastguard Worker #include <utility>
24*14675a02SAndroid Build Coastguard Worker #include <variant>
25*14675a02SAndroid Build Coastguard Worker
26*14675a02SAndroid Build Coastguard Worker #include "google/protobuf/duration.pb.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
28*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
29*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
30*14675a02SAndroid Build Coastguard Worker #include "absl/types/span.h"
31*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
32*14675a02SAndroid Build Coastguard Worker #include "fcp/base/time_util.h"
33*14675a02SAndroid Build Coastguard Worker #include "fcp/client/diag_codes.pb.h"
34*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/engine.pb.h"
35*14675a02SAndroid Build Coastguard Worker #include "fcp/client/event_publisher.h"
36*14675a02SAndroid Build Coastguard Worker #include "fcp/client/federated_protocol.h"
37*14675a02SAndroid Build Coastguard Worker #include "fcp/client/federated_protocol_util.h"
38*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.pb.h"
39*14675a02SAndroid Build Coastguard Worker #include "fcp/client/flags.h"
40*14675a02SAndroid Build Coastguard Worker #include "fcp/client/grpc_bidi_stream.h"
41*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/http_client.h"
42*14675a02SAndroid Build Coastguard Worker #include "fcp/client/http/in_memory_request_response.h"
43*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
44*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h"
45*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_logger.h"
46*14675a02SAndroid Build Coastguard Worker #include "fcp/client/secagg_event_publisher.h"
47*14675a02SAndroid Build Coastguard Worker #include "fcp/client/secagg_runner.h"
48*14675a02SAndroid Build Coastguard Worker #include "fcp/client/stats.h"
49*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/federated_api.pb.h"
50*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
51*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/client/secagg_client.h"
52*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/client/send_to_server_interface.h"
53*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/client/state_transition_listener_interface.h"
54*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/shared/aes_ctr_prng_factory.h"
55*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/shared/crypto_rand_prng.h"
56*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/shared/input_vector_specification.h"
57*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/shared/math.h"
58*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/shared/secagg_messages.pb.h"
59*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/shared/secagg_vector.h"
60*14675a02SAndroid Build Coastguard Worker
61*14675a02SAndroid Build Coastguard Worker namespace fcp {
62*14675a02SAndroid Build Coastguard Worker namespace client {
63*14675a02SAndroid Build Coastguard Worker
64*14675a02SAndroid Build Coastguard Worker using ::fcp::client::http::UriOrInlineData;
65*14675a02SAndroid Build Coastguard Worker using ::fcp::secagg::ClientToServerWrapperMessage;
66*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::CheckinRequest;
67*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::CheckinRequestAck;
68*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::CheckinResponse;
69*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::ClientExecutionStats;
70*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::ClientStreamMessage;
71*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::EligibilityEvalCheckinRequest;
72*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::EligibilityEvalCheckinResponse;
73*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::EligibilityEvalPayload;
74*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::HttpCompressionFormat;
75*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::ProtocolOptionsRequest;
76*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::RetryWindow;
77*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::ServerStreamMessage;
78*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::SideChannelExecutionInfo;
79*14675a02SAndroid Build Coastguard Worker using ::google::internal::federatedml::v2::TaskEligibilityInfo;
80*14675a02SAndroid Build Coastguard Worker
81*14675a02SAndroid Build Coastguard Worker // A note on error handling:
82*14675a02SAndroid Build Coastguard Worker //
83*14675a02SAndroid Build Coastguard Worker // The implementation here makes a distinction between what we call 'transient'
84*14675a02SAndroid Build Coastguard Worker // and 'permanent' errors. While the exact categorization of transient vs.
85*14675a02SAndroid Build Coastguard Worker // permanent errors is defined by a flag, the intent is that transient errors
86*14675a02SAndroid Build Coastguard Worker // are those types of errors that may occur in the regular course of business,
87*14675a02SAndroid Build Coastguard Worker // e.g. due to an interrupted network connection, a load balancer temporarily
88*14675a02SAndroid Build Coastguard Worker // rejecting our request etc. Generally, these are expected to be resolvable by
89*14675a02SAndroid Build Coastguard Worker // merely retrying the request at a slightly later time. Permanent errors are
90*14675a02SAndroid Build Coastguard Worker // intended to be those that are not expected to be resolvable as quickly or by
91*14675a02SAndroid Build Coastguard Worker // merely retrying the request. E.g. if a client checks in to the server with a
92*14675a02SAndroid Build Coastguard Worker // population name that doesn't exist, then the server may return NOT_FOUND, and
93*14675a02SAndroid Build Coastguard Worker // until the server-side configuration is changed, it will continue returning
94*14675a02SAndroid Build Coastguard Worker // such an error. Hence, such errors can warrant a longer retry period (to waste
95*14675a02SAndroid Build Coastguard Worker // less of both the client's and server's resources).
96*14675a02SAndroid Build Coastguard Worker //
97*14675a02SAndroid Build Coastguard Worker // The errors also differ in how they interact with the server-specified retry
98*14675a02SAndroid Build Coastguard Worker // windows that are returned via the CheckinRequestAck message.
99*14675a02SAndroid Build Coastguard Worker // - If a permanent error occurs, then we will always return a retry window
100*14675a02SAndroid Build Coastguard Worker // based on the target 'permanent errors retry period' flag, regardless of
101*14675a02SAndroid Build Coastguard Worker // whether we received a CheckinRequestAck from the server at an earlier time.
102*14675a02SAndroid Build Coastguard Worker // - If a transient error occurs, then we will only return a retry window
103*14675a02SAndroid Build Coastguard Worker // based on the target 'transient errors retry period' flag if the server
104*14675a02SAndroid Build Coastguard Worker // didn't already return a CheckinRequestAck. If it did return such an ack,
105*14675a02SAndroid Build Coastguard Worker // then one of the retry windows in that message will be used instead.
106*14675a02SAndroid Build Coastguard Worker //
// Finally, note that for simplicity's sake we generally check whether a
// permanent error was received at the level of this class's public methods,
// rather than deeper down in each of the helper methods that call directly
// into the gRPC stack. This keeps our state-managing code simpler, but it does
// mean that if a helper method like SendCheckinRequest produces a permanent
// error code locally (i.e. without it being sent by the server), that error
// will be treated as if the server sent it, and the permanent error retry
// period will be used. We consider this a reasonable tradeoff.
115*14675a02SAndroid Build Coastguard Worker
GrpcFederatedProtocol(EventPublisher * event_publisher,LogManager * log_manager,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,const Flags * flags,::fcp::client::http::HttpClient * http_client,const std::string & federated_service_uri,const std::string & api_key,const std::string & test_cert_path,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,const int64_t grpc_channel_deadline_seconds,cache::ResourceCache * resource_cache)116*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::GrpcFederatedProtocol(
117*14675a02SAndroid Build Coastguard Worker EventPublisher* event_publisher, LogManager* log_manager,
118*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
119*14675a02SAndroid Build Coastguard Worker const Flags* flags, ::fcp::client::http::HttpClient* http_client,
120*14675a02SAndroid Build Coastguard Worker const std::string& federated_service_uri, const std::string& api_key,
121*14675a02SAndroid Build Coastguard Worker const std::string& test_cert_path, absl::string_view population_name,
122*14675a02SAndroid Build Coastguard Worker absl::string_view retry_token, absl::string_view client_version,
123*14675a02SAndroid Build Coastguard Worker absl::string_view attestation_measurement,
124*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort,
125*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config,
126*14675a02SAndroid Build Coastguard Worker const int64_t grpc_channel_deadline_seconds,
127*14675a02SAndroid Build Coastguard Worker cache::ResourceCache* resource_cache)
128*14675a02SAndroid Build Coastguard Worker : GrpcFederatedProtocol(
129*14675a02SAndroid Build Coastguard Worker event_publisher, log_manager, std::move(secagg_runner_factory), flags,
130*14675a02SAndroid Build Coastguard Worker http_client,
131*14675a02SAndroid Build Coastguard Worker std::make_unique<GrpcBidiStream>(
132*14675a02SAndroid Build Coastguard Worker federated_service_uri, api_key, std::string(population_name),
133*14675a02SAndroid Build Coastguard Worker grpc_channel_deadline_seconds, test_cert_path),
134*14675a02SAndroid Build Coastguard Worker population_name, retry_token, client_version, attestation_measurement,
135*14675a02SAndroid Build Coastguard Worker should_abort, absl::BitGen(), timing_config, resource_cache) {}
136*14675a02SAndroid Build Coastguard Worker
GrpcFederatedProtocol(EventPublisher * event_publisher,LogManager * log_manager,std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,const Flags * flags,::fcp::client::http::HttpClient * http_client,std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,absl::string_view population_name,absl::string_view retry_token,absl::string_view client_version,absl::string_view attestation_measurement,std::function<bool ()> should_abort,absl::BitGen bit_gen,const InterruptibleRunner::TimingConfig & timing_config,cache::ResourceCache * resource_cache)137*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::GrpcFederatedProtocol(
138*14675a02SAndroid Build Coastguard Worker EventPublisher* event_publisher, LogManager* log_manager,
139*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
140*14675a02SAndroid Build Coastguard Worker const Flags* flags, ::fcp::client::http::HttpClient* http_client,
141*14675a02SAndroid Build Coastguard Worker std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
142*14675a02SAndroid Build Coastguard Worker absl::string_view population_name, absl::string_view retry_token,
143*14675a02SAndroid Build Coastguard Worker absl::string_view client_version, absl::string_view attestation_measurement,
144*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort, absl::BitGen bit_gen,
145*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config,
146*14675a02SAndroid Build Coastguard Worker cache::ResourceCache* resource_cache)
147*14675a02SAndroid Build Coastguard Worker : object_state_(ObjectState::kInitialized),
148*14675a02SAndroid Build Coastguard Worker event_publisher_(event_publisher),
149*14675a02SAndroid Build Coastguard Worker log_manager_(log_manager),
150*14675a02SAndroid Build Coastguard Worker secagg_runner_factory_(std::move(secagg_runner_factory)),
151*14675a02SAndroid Build Coastguard Worker flags_(flags),
152*14675a02SAndroid Build Coastguard Worker http_client_(http_client),
153*14675a02SAndroid Build Coastguard Worker grpc_bidi_stream_(std::move(grpc_bidi_stream)),
154*14675a02SAndroid Build Coastguard Worker population_name_(population_name),
155*14675a02SAndroid Build Coastguard Worker retry_token_(retry_token),
156*14675a02SAndroid Build Coastguard Worker client_version_(client_version),
157*14675a02SAndroid Build Coastguard Worker attestation_measurement_(attestation_measurement),
158*14675a02SAndroid Build Coastguard Worker bit_gen_(std::move(bit_gen)),
159*14675a02SAndroid Build Coastguard Worker resource_cache_(resource_cache) {
160*14675a02SAndroid Build Coastguard Worker interruptible_runner_ = std::make_unique<InterruptibleRunner>(
161*14675a02SAndroid Build Coastguard Worker log_manager, should_abort, timing_config,
162*14675a02SAndroid Build Coastguard Worker InterruptibleRunner::DiagnosticsConfig{
163*14675a02SAndroid Build Coastguard Worker .interrupted = ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC,
164*14675a02SAndroid Build Coastguard Worker .interrupt_timeout =
165*14675a02SAndroid Build Coastguard Worker ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_GRPC_TIMED_OUT,
166*14675a02SAndroid Build Coastguard Worker .interrupted_extended = ProdDiagCode::
167*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_COMPLETED,
168*14675a02SAndroid Build Coastguard Worker .interrupt_timeout_extended = ProdDiagCode::
169*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_INTERRUPT_GRPC_EXTENDED_TIMED_OUT});
170*14675a02SAndroid Build Coastguard Worker // Note that we could cast the provided error codes to absl::StatusCode
171*14675a02SAndroid Build Coastguard Worker // values here. However, that means we'd have to handle the case when
172*14675a02SAndroid Build Coastguard Worker // invalid integers that don't map to a StatusCode enum are provided in the
173*14675a02SAndroid Build Coastguard Worker // flag here. Instead, we cast absl::StatusCodes to int32_t each time we
174*14675a02SAndroid Build Coastguard Worker // compare them with the flag-provided list of codes, which means we never
175*14675a02SAndroid Build Coastguard Worker // have to worry about invalid flag values (besides the fact that invalid
176*14675a02SAndroid Build Coastguard Worker // values will be silently ignored, which could make it harder to realize when
177*14675a02SAndroid Build Coastguard Worker // flag is misconfigured).
178*14675a02SAndroid Build Coastguard Worker const std::vector<int32_t>& error_codes =
179*14675a02SAndroid Build Coastguard Worker flags->federated_training_permanent_error_codes();
180*14675a02SAndroid Build Coastguard Worker federated_training_permanent_error_codes_ =
181*14675a02SAndroid Build Coastguard Worker absl::flat_hash_set<int32_t>(error_codes.begin(), error_codes.end());
182*14675a02SAndroid Build Coastguard Worker }
183*14675a02SAndroid Build Coastguard Worker
~GrpcFederatedProtocol()184*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::~GrpcFederatedProtocol() { grpc_bidi_stream_->Close(); }
185*14675a02SAndroid Build Coastguard Worker
Send(google::internal::federatedml::v2::ClientStreamMessage * client_stream_message)186*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::Send(
187*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ClientStreamMessage*
188*14675a02SAndroid Build Coastguard Worker client_stream_message) {
189*14675a02SAndroid Build Coastguard Worker // Note that this stopwatch measurement may not fully measure the time it
190*14675a02SAndroid Build Coastguard Worker // takes to send all of the data, as it may return before all data was written
191*14675a02SAndroid Build Coastguard Worker // to the network socket. It's the best estimate we can provide though.
192*14675a02SAndroid Build Coastguard Worker auto started_stopwatch = network_stopwatch_->Start();
193*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
194*14675a02SAndroid Build Coastguard Worker [this, &client_stream_message]() {
195*14675a02SAndroid Build Coastguard Worker return this->grpc_bidi_stream_->Send(client_stream_message);
196*14675a02SAndroid Build Coastguard Worker },
197*14675a02SAndroid Build Coastguard Worker [this]() { this->grpc_bidi_stream_->Close(); }));
198*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
199*14675a02SAndroid Build Coastguard Worker }
200*14675a02SAndroid Build Coastguard Worker
Receive(google::internal::federatedml::v2::ServerStreamMessage * server_stream_message)201*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::Receive(
202*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ServerStreamMessage*
203*14675a02SAndroid Build Coastguard Worker server_stream_message) {
204*14675a02SAndroid Build Coastguard Worker // Note that this stopwatch measurement will generally include time spent
205*14675a02SAndroid Build Coastguard Worker // waiting for the server to return a response (i.e. idle time rather than the
206*14675a02SAndroid Build Coastguard Worker // true time it takes to send/receive data on the network). It's the best
207*14675a02SAndroid Build Coastguard Worker // estimate we can provide though.
208*14675a02SAndroid Build Coastguard Worker auto started_stopwatch = network_stopwatch_->Start();
209*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
210*14675a02SAndroid Build Coastguard Worker [this, &server_stream_message]() {
211*14675a02SAndroid Build Coastguard Worker return grpc_bidi_stream_->Receive(server_stream_message);
212*14675a02SAndroid Build Coastguard Worker },
213*14675a02SAndroid Build Coastguard Worker [this]() { this->grpc_bidi_stream_->Close(); }));
214*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
215*14675a02SAndroid Build Coastguard Worker }
216*14675a02SAndroid Build Coastguard Worker
CreateProtocolOptionsRequest(bool should_ack_checkin) const217*14675a02SAndroid Build Coastguard Worker ProtocolOptionsRequest GrpcFederatedProtocol::CreateProtocolOptionsRequest(
218*14675a02SAndroid Build Coastguard Worker bool should_ack_checkin) const {
219*14675a02SAndroid Build Coastguard Worker ProtocolOptionsRequest request;
220*14675a02SAndroid Build Coastguard Worker request.set_should_ack_checkin(should_ack_checkin);
221*14675a02SAndroid Build Coastguard Worker request.set_supports_http_download(http_client_ != nullptr);
222*14675a02SAndroid Build Coastguard Worker request.set_supports_eligibility_eval_http_download(
223*14675a02SAndroid Build Coastguard Worker http_client_ != nullptr &&
224*14675a02SAndroid Build Coastguard Worker flags_->enable_grpc_with_eligibility_eval_http_resource_support());
225*14675a02SAndroid Build Coastguard Worker
226*14675a02SAndroid Build Coastguard Worker // Note that we set this field for both eligibility eval checkin requests
227*14675a02SAndroid Build Coastguard Worker // and regular checkin requests. Even though eligibility eval tasks do not
228*14675a02SAndroid Build Coastguard Worker // have any aggregation phase, we still advertise the client's support for
229*14675a02SAndroid Build Coastguard Worker // Secure Aggregation during the eligibility eval checkin phase. We do
230*14675a02SAndroid Build Coastguard Worker // this because it doesn't hurt anything, and because letting the server
231*14675a02SAndroid Build Coastguard Worker // know whether client supports SecAgg sooner rather than later in the
232*14675a02SAndroid Build Coastguard Worker // protocol seems to provide maximum flexibility if the server ever were
233*14675a02SAndroid Build Coastguard Worker // to use that information at this stage of the protocol in the future.
234*14675a02SAndroid Build Coastguard Worker request.mutable_side_channels()
235*14675a02SAndroid Build Coastguard Worker ->mutable_secure_aggregation()
236*14675a02SAndroid Build Coastguard Worker ->add_client_variant(secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1);
237*14675a02SAndroid Build Coastguard Worker request.mutable_supported_http_compression_formats()->Add(
238*14675a02SAndroid Build Coastguard Worker HttpCompressionFormat::HTTP_COMPRESSION_FORMAT_GZIP);
239*14675a02SAndroid Build Coastguard Worker return request;
240*14675a02SAndroid Build Coastguard Worker }
241*14675a02SAndroid Build Coastguard Worker
SendEligibilityEvalCheckinRequest()242*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::SendEligibilityEvalCheckinRequest() {
243*14675a02SAndroid Build Coastguard Worker ClientStreamMessage client_stream_message;
244*14675a02SAndroid Build Coastguard Worker EligibilityEvalCheckinRequest* eligibility_checkin_request =
245*14675a02SAndroid Build Coastguard Worker client_stream_message.mutable_eligibility_eval_checkin_request();
246*14675a02SAndroid Build Coastguard Worker eligibility_checkin_request->set_population_name(population_name_);
247*14675a02SAndroid Build Coastguard Worker eligibility_checkin_request->set_retry_token(retry_token_);
248*14675a02SAndroid Build Coastguard Worker eligibility_checkin_request->set_client_version(client_version_);
249*14675a02SAndroid Build Coastguard Worker eligibility_checkin_request->set_attestation_measurement(
250*14675a02SAndroid Build Coastguard Worker attestation_measurement_);
251*14675a02SAndroid Build Coastguard Worker *eligibility_checkin_request->mutable_protocol_options_request() =
252*14675a02SAndroid Build Coastguard Worker CreateProtocolOptionsRequest(
253*14675a02SAndroid Build Coastguard Worker /* should_ack_checkin=*/true);
254*14675a02SAndroid Build Coastguard Worker
255*14675a02SAndroid Build Coastguard Worker return Send(&client_stream_message);
256*14675a02SAndroid Build Coastguard Worker }
257*14675a02SAndroid Build Coastguard Worker
SendCheckinRequest(const std::optional<TaskEligibilityInfo> & task_eligibility_info)258*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::SendCheckinRequest(
259*14675a02SAndroid Build Coastguard Worker const std::optional<TaskEligibilityInfo>& task_eligibility_info) {
260*14675a02SAndroid Build Coastguard Worker ClientStreamMessage client_stream_message;
261*14675a02SAndroid Build Coastguard Worker CheckinRequest* checkin_request =
262*14675a02SAndroid Build Coastguard Worker client_stream_message.mutable_checkin_request();
263*14675a02SAndroid Build Coastguard Worker checkin_request->set_population_name(population_name_);
264*14675a02SAndroid Build Coastguard Worker checkin_request->set_retry_token(retry_token_);
265*14675a02SAndroid Build Coastguard Worker checkin_request->set_client_version(client_version_);
266*14675a02SAndroid Build Coastguard Worker checkin_request->set_attestation_measurement(attestation_measurement_);
267*14675a02SAndroid Build Coastguard Worker *checkin_request->mutable_protocol_options_request() =
268*14675a02SAndroid Build Coastguard Worker CreateProtocolOptionsRequest(/* should_ack_checkin=*/false);
269*14675a02SAndroid Build Coastguard Worker
270*14675a02SAndroid Build Coastguard Worker if (task_eligibility_info.has_value()) {
271*14675a02SAndroid Build Coastguard Worker *checkin_request->mutable_task_eligibility_info() = *task_eligibility_info;
272*14675a02SAndroid Build Coastguard Worker }
273*14675a02SAndroid Build Coastguard Worker
274*14675a02SAndroid Build Coastguard Worker return Send(&client_stream_message);
275*14675a02SAndroid Build Coastguard Worker }
276*14675a02SAndroid Build Coastguard Worker
ReceiveCheckinRequestAck()277*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::ReceiveCheckinRequestAck() {
278*14675a02SAndroid Build Coastguard Worker // Wait for a CheckinRequestAck.
279*14675a02SAndroid Build Coastguard Worker ServerStreamMessage server_stream_message;
280*14675a02SAndroid Build Coastguard Worker absl::Status receive_status = Receive(&server_stream_message);
281*14675a02SAndroid Build Coastguard Worker if (receive_status.code() == absl::StatusCode::kNotFound) {
282*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Server responded NOT_FOUND to checkin request, "
283*14675a02SAndroid Build Coastguard Worker "population name '"
284*14675a02SAndroid Build Coastguard Worker << population_name_ << "' is likely incorrect.";
285*14675a02SAndroid Build Coastguard Worker }
286*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(receive_status);
287*14675a02SAndroid Build Coastguard Worker
288*14675a02SAndroid Build Coastguard Worker if (!server_stream_message.has_checkin_request_ack()) {
289*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
290*14675a02SAndroid Build Coastguard Worker ProdDiagCode::
291*14675a02SAndroid Build Coastguard Worker BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_EXPECTED_BUT_NOT_RECVD);
292*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(
293*14675a02SAndroid Build Coastguard Worker "Requested but did not receive CheckinRequestAck");
294*14675a02SAndroid Build Coastguard Worker }
295*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
296*14675a02SAndroid Build Coastguard Worker ProdDiagCode::BACKGROUND_TRAINING_CHECKIN_REQUEST_ACK_RECEIVED);
297*14675a02SAndroid Build Coastguard Worker // Process the received CheckinRequestAck message.
298*14675a02SAndroid Build Coastguard Worker const CheckinRequestAck& checkin_request_ack =
299*14675a02SAndroid Build Coastguard Worker server_stream_message.checkin_request_ack();
300*14675a02SAndroid Build Coastguard Worker if (!checkin_request_ack.has_retry_window_if_accepted() ||
301*14675a02SAndroid Build Coastguard Worker !checkin_request_ack.has_retry_window_if_rejected()) {
302*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(
303*14675a02SAndroid Build Coastguard Worker "Received CheckinRequestAck message with missing retry windows");
304*14675a02SAndroid Build Coastguard Worker }
305*14675a02SAndroid Build Coastguard Worker // Upon receiving the server's RetryWindows we immediately choose a concrete
306*14675a02SAndroid Build Coastguard Worker // target timestamp to retry at. This ensures that a) clients of this class
307*14675a02SAndroid Build Coastguard Worker // don't have to implement the logic to select a timestamp from a min/max
308*14675a02SAndroid Build Coastguard Worker // range themselves, b) we tell clients of this class to come back at exactly
309*14675a02SAndroid Build Coastguard Worker // a point in time the server intended us to come at (i.e. "now +
310*14675a02SAndroid Build Coastguard Worker // server_specified_retry_period", and not a point in time that is partly
311*14675a02SAndroid Build Coastguard Worker // determined by how long the remaining protocol interactions (e.g. training
312*14675a02SAndroid Build Coastguard Worker // and results upload) will take (i.e. "now +
313*14675a02SAndroid Build Coastguard Worker // duration_of_remaining_protocol_interactions +
314*14675a02SAndroid Build Coastguard Worker // server_specified_retry_period").
315*14675a02SAndroid Build Coastguard Worker checkin_request_ack_info_ = CheckinRequestAckInfo{
316*14675a02SAndroid Build Coastguard Worker .retry_info_if_rejected =
317*14675a02SAndroid Build Coastguard Worker RetryTimeAndToken{
318*14675a02SAndroid Build Coastguard Worker PickRetryTimeFromRange(
319*14675a02SAndroid Build Coastguard Worker checkin_request_ack.retry_window_if_rejected().delay_min(),
320*14675a02SAndroid Build Coastguard Worker checkin_request_ack.retry_window_if_rejected().delay_max(),
321*14675a02SAndroid Build Coastguard Worker bit_gen_),
322*14675a02SAndroid Build Coastguard Worker checkin_request_ack.retry_window_if_rejected().retry_token()},
323*14675a02SAndroid Build Coastguard Worker .retry_info_if_accepted = RetryTimeAndToken{
324*14675a02SAndroid Build Coastguard Worker PickRetryTimeFromRange(
325*14675a02SAndroid Build Coastguard Worker checkin_request_ack.retry_window_if_accepted().delay_min(),
326*14675a02SAndroid Build Coastguard Worker checkin_request_ack.retry_window_if_accepted().delay_max(),
327*14675a02SAndroid Build Coastguard Worker bit_gen_),
328*14675a02SAndroid Build Coastguard Worker checkin_request_ack.retry_window_if_accepted().retry_token()}};
329*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
330*14675a02SAndroid Build Coastguard Worker }
331*14675a02SAndroid Build Coastguard Worker
332*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
ReceiveEligibilityEvalCheckinResponse(absl::Time start_time,std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)333*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::ReceiveEligibilityEvalCheckinResponse(
334*14675a02SAndroid Build Coastguard Worker absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
335*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback) {
336*14675a02SAndroid Build Coastguard Worker ServerStreamMessage server_stream_message;
337*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(Receive(&server_stream_message));
338*14675a02SAndroid Build Coastguard Worker
339*14675a02SAndroid Build Coastguard Worker if (!server_stream_message.has_eligibility_eval_checkin_response()) {
340*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(
341*14675a02SAndroid Build Coastguard Worker absl::StrCat("Bad response to EligibilityEvalCheckinRequest; Expected "
342*14675a02SAndroid Build Coastguard Worker "EligibilityEvalCheckinResponse but got ",
343*14675a02SAndroid Build Coastguard Worker server_stream_message.kind_case(), "."));
344*14675a02SAndroid Build Coastguard Worker }
345*14675a02SAndroid Build Coastguard Worker
346*14675a02SAndroid Build Coastguard Worker const EligibilityEvalCheckinResponse& eligibility_checkin_response =
347*14675a02SAndroid Build Coastguard Worker server_stream_message.eligibility_eval_checkin_response();
348*14675a02SAndroid Build Coastguard Worker switch (eligibility_checkin_response.checkin_result_case()) {
349*14675a02SAndroid Build Coastguard Worker case EligibilityEvalCheckinResponse::kEligibilityEvalPayload: {
350*14675a02SAndroid Build Coastguard Worker const EligibilityEvalPayload& eligibility_eval_payload =
351*14675a02SAndroid Build Coastguard Worker eligibility_checkin_response.eligibility_eval_payload();
352*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kEligibilityEvalEnabled;
353*14675a02SAndroid Build Coastguard Worker EligibilityEvalTask result{.execution_id =
354*14675a02SAndroid Build Coastguard Worker eligibility_eval_payload.execution_id()};
355*14675a02SAndroid Build Coastguard Worker
356*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback(result);
357*14675a02SAndroid Build Coastguard Worker
358*14675a02SAndroid Build Coastguard Worker PlanAndCheckpointPayloads payloads;
359*14675a02SAndroid Build Coastguard Worker if (http_client_ == nullptr ||
360*14675a02SAndroid Build Coastguard Worker !flags_->enable_grpc_with_eligibility_eval_http_resource_support()) {
361*14675a02SAndroid Build Coastguard Worker result.payloads = {
362*14675a02SAndroid Build Coastguard Worker .plan = eligibility_eval_payload.plan(),
363*14675a02SAndroid Build Coastguard Worker .checkpoint = eligibility_eval_payload.init_checkpoint()};
364*14675a02SAndroid Build Coastguard Worker } else {
365*14675a02SAndroid Build Coastguard Worker // Fetch the task resources, returning any errors that may be
366*14675a02SAndroid Build Coastguard Worker // encountered in the process.
367*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(
368*14675a02SAndroid Build Coastguard Worker result.payloads,
369*14675a02SAndroid Build Coastguard Worker FetchTaskResources(
370*14675a02SAndroid Build Coastguard Worker {.plan =
371*14675a02SAndroid Build Coastguard Worker {
372*14675a02SAndroid Build Coastguard Worker .has_uri =
373*14675a02SAndroid Build Coastguard Worker eligibility_eval_payload.has_plan_resource(),
374*14675a02SAndroid Build Coastguard Worker .uri = eligibility_eval_payload.plan_resource().uri(),
375*14675a02SAndroid Build Coastguard Worker .data = eligibility_eval_payload.plan(),
376*14675a02SAndroid Build Coastguard Worker .client_cache_id =
377*14675a02SAndroid Build Coastguard Worker eligibility_eval_payload.plan_resource()
378*14675a02SAndroid Build Coastguard Worker .client_cache_id(),
379*14675a02SAndroid Build Coastguard Worker .max_age = TimeUtil::ConvertProtoToAbslDuration(
380*14675a02SAndroid Build Coastguard Worker eligibility_eval_payload.plan_resource()
381*14675a02SAndroid Build Coastguard Worker .max_age()),
382*14675a02SAndroid Build Coastguard Worker },
383*14675a02SAndroid Build Coastguard Worker .checkpoint = {
384*14675a02SAndroid Build Coastguard Worker .has_uri = eligibility_eval_payload
385*14675a02SAndroid Build Coastguard Worker .has_init_checkpoint_resource(),
386*14675a02SAndroid Build Coastguard Worker .uri = eligibility_eval_payload.init_checkpoint_resource()
387*14675a02SAndroid Build Coastguard Worker .uri(),
388*14675a02SAndroid Build Coastguard Worker .data = eligibility_eval_payload.init_checkpoint(),
389*14675a02SAndroid Build Coastguard Worker .client_cache_id =
390*14675a02SAndroid Build Coastguard Worker eligibility_eval_payload.init_checkpoint_resource()
391*14675a02SAndroid Build Coastguard Worker .client_cache_id(),
392*14675a02SAndroid Build Coastguard Worker .max_age = TimeUtil::ConvertProtoToAbslDuration(
393*14675a02SAndroid Build Coastguard Worker eligibility_eval_payload.init_checkpoint_resource()
394*14675a02SAndroid Build Coastguard Worker .max_age()),
395*14675a02SAndroid Build Coastguard Worker }}));
396*14675a02SAndroid Build Coastguard Worker }
397*14675a02SAndroid Build Coastguard Worker return std::move(result);
398*14675a02SAndroid Build Coastguard Worker }
399*14675a02SAndroid Build Coastguard Worker case EligibilityEvalCheckinResponse::kNoEligibilityEvalConfigured: {
400*14675a02SAndroid Build Coastguard Worker // Nothing to do...
401*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kEligibilityEvalDisabled;
402*14675a02SAndroid Build Coastguard Worker return EligibilityEvalDisabled{};
403*14675a02SAndroid Build Coastguard Worker }
404*14675a02SAndroid Build Coastguard Worker case EligibilityEvalCheckinResponse::kRejectionInfo: {
405*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kEligibilityEvalCheckinRejected;
406*14675a02SAndroid Build Coastguard Worker return Rejection{};
407*14675a02SAndroid Build Coastguard Worker }
408*14675a02SAndroid Build Coastguard Worker default:
409*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(
410*14675a02SAndroid Build Coastguard Worker "Unrecognized EligibilityEvalCheckinResponse");
411*14675a02SAndroid Build Coastguard Worker }
412*14675a02SAndroid Build Coastguard Worker }
413*14675a02SAndroid Build Coastguard Worker
414*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FederatedProtocol::CheckinResult>
ReceiveCheckinResponse(absl::Time start_time,std::function<void (const TaskAssignment &)> payload_uris_received_callback)415*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::ReceiveCheckinResponse(
416*14675a02SAndroid Build Coastguard Worker absl::Time start_time,
417*14675a02SAndroid Build Coastguard Worker std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
418*14675a02SAndroid Build Coastguard Worker ServerStreamMessage server_stream_message;
419*14675a02SAndroid Build Coastguard Worker absl::Status receive_status = Receive(&server_stream_message);
420*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(receive_status);
421*14675a02SAndroid Build Coastguard Worker
422*14675a02SAndroid Build Coastguard Worker if (!server_stream_message.has_checkin_response()) {
423*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(absl::StrCat(
424*14675a02SAndroid Build Coastguard Worker "Bad response to CheckinRequest; Expected CheckinResponse but got ",
425*14675a02SAndroid Build Coastguard Worker server_stream_message.kind_case(), "."));
426*14675a02SAndroid Build Coastguard Worker }
427*14675a02SAndroid Build Coastguard Worker
428*14675a02SAndroid Build Coastguard Worker const CheckinResponse& checkin_response =
429*14675a02SAndroid Build Coastguard Worker server_stream_message.checkin_response();
430*14675a02SAndroid Build Coastguard Worker
431*14675a02SAndroid Build Coastguard Worker execution_phase_id_ =
432*14675a02SAndroid Build Coastguard Worker checkin_response.has_acceptance_info()
433*14675a02SAndroid Build Coastguard Worker ? checkin_response.acceptance_info().execution_phase_id()
434*14675a02SAndroid Build Coastguard Worker : "";
435*14675a02SAndroid Build Coastguard Worker switch (checkin_response.checkin_result_case()) {
436*14675a02SAndroid Build Coastguard Worker case CheckinResponse::kAcceptanceInfo: {
437*14675a02SAndroid Build Coastguard Worker const auto& acceptance_info = checkin_response.acceptance_info();
438*14675a02SAndroid Build Coastguard Worker
439*14675a02SAndroid Build Coastguard Worker for (const auto& [k, v] : acceptance_info.side_channels())
440*14675a02SAndroid Build Coastguard Worker side_channels_[k] = v;
441*14675a02SAndroid Build Coastguard Worker side_channel_protocol_execution_info_ =
442*14675a02SAndroid Build Coastguard Worker acceptance_info.side_channel_protocol_execution_info();
443*14675a02SAndroid Build Coastguard Worker side_channel_protocol_options_response_ =
444*14675a02SAndroid Build Coastguard Worker checkin_response.protocol_options_response().side_channels();
445*14675a02SAndroid Build Coastguard Worker
446*14675a02SAndroid Build Coastguard Worker std::optional<SecAggInfo> sec_agg_info = std::nullopt;
447*14675a02SAndroid Build Coastguard Worker if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
448*14675a02SAndroid Build Coastguard Worker sec_agg_info = SecAggInfo{
449*14675a02SAndroid Build Coastguard Worker .expected_number_of_clients =
450*14675a02SAndroid Build Coastguard Worker side_channel_protocol_execution_info_.secure_aggregation()
451*14675a02SAndroid Build Coastguard Worker .expected_number_of_clients(),
452*14675a02SAndroid Build Coastguard Worker .minimum_clients_in_server_visible_aggregate =
453*14675a02SAndroid Build Coastguard Worker side_channel_protocol_execution_info_.secure_aggregation()
454*14675a02SAndroid Build Coastguard Worker .minimum_clients_in_server_visible_aggregate()};
455*14675a02SAndroid Build Coastguard Worker }
456*14675a02SAndroid Build Coastguard Worker
457*14675a02SAndroid Build Coastguard Worker TaskAssignment result{
458*14675a02SAndroid Build Coastguard Worker .federated_select_uri_template =
459*14675a02SAndroid Build Coastguard Worker acceptance_info.federated_select_uri_info().uri_template(),
460*14675a02SAndroid Build Coastguard Worker .aggregation_session_id = acceptance_info.execution_phase_id(),
461*14675a02SAndroid Build Coastguard Worker .sec_agg_info = sec_agg_info};
462*14675a02SAndroid Build Coastguard Worker
463*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback(result);
464*14675a02SAndroid Build Coastguard Worker
465*14675a02SAndroid Build Coastguard Worker PlanAndCheckpointPayloads payloads;
466*14675a02SAndroid Build Coastguard Worker if (http_client_ == nullptr) {
467*14675a02SAndroid Build Coastguard Worker result.payloads = {.plan = acceptance_info.plan(),
468*14675a02SAndroid Build Coastguard Worker .checkpoint = acceptance_info.init_checkpoint()};
469*14675a02SAndroid Build Coastguard Worker } else {
470*14675a02SAndroid Build Coastguard Worker // Fetch the task resources, returning any errors that may be
471*14675a02SAndroid Build Coastguard Worker // encountered in the process.
472*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(
473*14675a02SAndroid Build Coastguard Worker result.payloads,
474*14675a02SAndroid Build Coastguard Worker FetchTaskResources(
475*14675a02SAndroid Build Coastguard Worker {.plan =
476*14675a02SAndroid Build Coastguard Worker {
477*14675a02SAndroid Build Coastguard Worker .has_uri = acceptance_info.has_plan_resource(),
478*14675a02SAndroid Build Coastguard Worker .uri = acceptance_info.plan_resource().uri(),
479*14675a02SAndroid Build Coastguard Worker .data = acceptance_info.plan(),
480*14675a02SAndroid Build Coastguard Worker .client_cache_id =
481*14675a02SAndroid Build Coastguard Worker acceptance_info.plan_resource().client_cache_id(),
482*14675a02SAndroid Build Coastguard Worker .max_age = TimeUtil::ConvertProtoToAbslDuration(
483*14675a02SAndroid Build Coastguard Worker acceptance_info.plan_resource().max_age()),
484*14675a02SAndroid Build Coastguard Worker },
485*14675a02SAndroid Build Coastguard Worker .checkpoint = {
486*14675a02SAndroid Build Coastguard Worker .has_uri = acceptance_info.has_init_checkpoint_resource(),
487*14675a02SAndroid Build Coastguard Worker .uri = acceptance_info.init_checkpoint_resource().uri(),
488*14675a02SAndroid Build Coastguard Worker .data = acceptance_info.init_checkpoint(),
489*14675a02SAndroid Build Coastguard Worker .client_cache_id =
490*14675a02SAndroid Build Coastguard Worker acceptance_info.init_checkpoint_resource()
491*14675a02SAndroid Build Coastguard Worker .client_cache_id(),
492*14675a02SAndroid Build Coastguard Worker .max_age = TimeUtil::ConvertProtoToAbslDuration(
493*14675a02SAndroid Build Coastguard Worker acceptance_info.init_checkpoint_resource().max_age()),
494*14675a02SAndroid Build Coastguard Worker }}));
495*14675a02SAndroid Build Coastguard Worker }
496*14675a02SAndroid Build Coastguard Worker
497*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kCheckinAccepted;
498*14675a02SAndroid Build Coastguard Worker return result;
499*14675a02SAndroid Build Coastguard Worker }
500*14675a02SAndroid Build Coastguard Worker case CheckinResponse::kRejectionInfo: {
501*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kCheckinRejected;
502*14675a02SAndroid Build Coastguard Worker return Rejection{};
503*14675a02SAndroid Build Coastguard Worker }
504*14675a02SAndroid Build Coastguard Worker default:
505*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError("Unrecognized CheckinResponse");
506*14675a02SAndroid Build Coastguard Worker }
507*14675a02SAndroid Build Coastguard Worker }
508*14675a02SAndroid Build Coastguard Worker
509*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FederatedProtocol::EligibilityEvalCheckinResult>
EligibilityEvalCheckin(std::function<void (const EligibilityEvalTask &)> payload_uris_received_callback)510*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::EligibilityEvalCheckin(
511*14675a02SAndroid Build Coastguard Worker std::function<void(const EligibilityEvalTask&)>
512*14675a02SAndroid Build Coastguard Worker payload_uris_received_callback) {
513*14675a02SAndroid Build Coastguard Worker FCP_CHECK(object_state_ == ObjectState::kInitialized)
514*14675a02SAndroid Build Coastguard Worker << "Invalid call sequence";
515*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kEligibilityEvalCheckinFailed;
516*14675a02SAndroid Build Coastguard Worker
517*14675a02SAndroid Build Coastguard Worker absl::Time start_time = absl::Now();
518*14675a02SAndroid Build Coastguard Worker
519*14675a02SAndroid Build Coastguard Worker // Send an EligibilityEvalCheckinRequest.
520*14675a02SAndroid Build Coastguard Worker absl::Status request_status = SendEligibilityEvalCheckinRequest();
521*14675a02SAndroid Build Coastguard Worker // See note about how we handle 'permanent' errors at the top of this file.
522*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(
523*14675a02SAndroid Build Coastguard Worker request_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
524*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(request_status);
525*14675a02SAndroid Build Coastguard Worker
526*14675a02SAndroid Build Coastguard Worker // Receive a CheckinRequestAck.
527*14675a02SAndroid Build Coastguard Worker absl::Status ack_status = ReceiveCheckinRequestAck();
528*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(
529*14675a02SAndroid Build Coastguard Worker ack_status, ObjectState::kEligibilityEvalCheckinFailedPermanentError);
530*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(ack_status);
531*14675a02SAndroid Build Coastguard Worker
532*14675a02SAndroid Build Coastguard Worker // Receive + handle an EligibilityEvalCheckinResponse message, and update the
533*14675a02SAndroid Build Coastguard Worker // object state based on the received response.
534*14675a02SAndroid Build Coastguard Worker auto response = ReceiveEligibilityEvalCheckinResponse(
535*14675a02SAndroid Build Coastguard Worker start_time, payload_uris_received_callback);
536*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(
537*14675a02SAndroid Build Coastguard Worker response.status(),
538*14675a02SAndroid Build Coastguard Worker ObjectState::kEligibilityEvalCheckinFailedPermanentError);
539*14675a02SAndroid Build Coastguard Worker return response;
540*14675a02SAndroid Build Coastguard Worker }
541*14675a02SAndroid Build Coastguard Worker
542*14675a02SAndroid Build Coastguard Worker // This is not supported in gRPC federated protocol, we'll do nothing.
ReportEligibilityEvalError(absl::Status error_status)543*14675a02SAndroid Build Coastguard Worker void GrpcFederatedProtocol::ReportEligibilityEvalError(
544*14675a02SAndroid Build Coastguard Worker absl::Status error_status) {}
545*14675a02SAndroid Build Coastguard Worker
Checkin(const std::optional<TaskEligibilityInfo> & task_eligibility_info,std::function<void (const TaskAssignment &)> payload_uris_received_callback)546*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FederatedProtocol::CheckinResult> GrpcFederatedProtocol::Checkin(
547*14675a02SAndroid Build Coastguard Worker const std::optional<TaskEligibilityInfo>& task_eligibility_info,
548*14675a02SAndroid Build Coastguard Worker std::function<void(const TaskAssignment&)> payload_uris_received_callback) {
549*14675a02SAndroid Build Coastguard Worker // Checkin(...) must follow an earlier call to EligibilityEvalCheckin() that
550*14675a02SAndroid Build Coastguard Worker // resulted in a CheckinResultPayload or an EligibilityEvalDisabled result.
551*14675a02SAndroid Build Coastguard Worker FCP_CHECK(object_state_ == ObjectState::kEligibilityEvalDisabled ||
552*14675a02SAndroid Build Coastguard Worker object_state_ == ObjectState::kEligibilityEvalEnabled)
553*14675a02SAndroid Build Coastguard Worker << "Checkin(...) called despite failed/rejected earlier "
554*14675a02SAndroid Build Coastguard Worker "EligibilityEvalCheckin";
555*14675a02SAndroid Build Coastguard Worker if (object_state_ == ObjectState::kEligibilityEvalEnabled) {
556*14675a02SAndroid Build Coastguard Worker FCP_CHECK(task_eligibility_info.has_value())
557*14675a02SAndroid Build Coastguard Worker << "Missing TaskEligibilityInfo despite receiving prior "
558*14675a02SAndroid Build Coastguard Worker "EligibilityEvalCheckin payload";
559*14675a02SAndroid Build Coastguard Worker } else {
560*14675a02SAndroid Build Coastguard Worker FCP_CHECK(!task_eligibility_info.has_value())
561*14675a02SAndroid Build Coastguard Worker << "Received TaskEligibilityInfo despite not receiving a prior "
562*14675a02SAndroid Build Coastguard Worker "EligibilityEvalCheckin payload";
563*14675a02SAndroid Build Coastguard Worker }
564*14675a02SAndroid Build Coastguard Worker
565*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kCheckinFailed;
566*14675a02SAndroid Build Coastguard Worker
567*14675a02SAndroid Build Coastguard Worker absl::Time start_time = absl::Now();
568*14675a02SAndroid Build Coastguard Worker // Send a CheckinRequest.
569*14675a02SAndroid Build Coastguard Worker absl::Status request_status = SendCheckinRequest(task_eligibility_info);
570*14675a02SAndroid Build Coastguard Worker // See note about how we handle 'permanent' errors at the top of this file.
571*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(request_status,
572*14675a02SAndroid Build Coastguard Worker ObjectState::kCheckinFailedPermanentError);
573*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(request_status);
574*14675a02SAndroid Build Coastguard Worker
575*14675a02SAndroid Build Coastguard Worker // Receive + handle a CheckinResponse message, and update the object state
576*14675a02SAndroid Build Coastguard Worker // based on the received response.
577*14675a02SAndroid Build Coastguard Worker auto response =
578*14675a02SAndroid Build Coastguard Worker ReceiveCheckinResponse(start_time, payload_uris_received_callback);
579*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(response.status(),
580*14675a02SAndroid Build Coastguard Worker ObjectState::kCheckinFailedPermanentError);
581*14675a02SAndroid Build Coastguard Worker return response;
582*14675a02SAndroid Build Coastguard Worker }
583*14675a02SAndroid Build Coastguard Worker
584*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FederatedProtocol::MultipleTaskAssignments>
PerformMultipleTaskAssignments(const std::vector<std::string> & task_names)585*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::PerformMultipleTaskAssignments(
586*14675a02SAndroid Build Coastguard Worker const std::vector<std::string>& task_names) {
587*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(
588*14675a02SAndroid Build Coastguard Worker "PerformMultipleTaskAssignments is not supported by "
589*14675a02SAndroid Build Coastguard Worker "GrpcFederatedProtocol.");
590*14675a02SAndroid Build Coastguard Worker }
591*14675a02SAndroid Build Coastguard Worker
ReportCompleted(ComputationResults results,absl::Duration plan_duration,std::optional<std::string> aggregation_session_id)592*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::ReportCompleted(
593*14675a02SAndroid Build Coastguard Worker ComputationResults results, absl::Duration plan_duration,
594*14675a02SAndroid Build Coastguard Worker std::optional<std::string> aggregation_session_id) {
595*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Reporting outcome: " << static_cast<int>(engine::COMPLETED);
596*14675a02SAndroid Build Coastguard Worker FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
597*14675a02SAndroid Build Coastguard Worker << "Invalid call sequence";
598*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kReportCalled;
599*14675a02SAndroid Build Coastguard Worker auto response = Report(std::move(results), engine::COMPLETED, plan_duration);
600*14675a02SAndroid Build Coastguard Worker // See note about how we handle 'permanent' errors at the top of this file.
601*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(response,
602*14675a02SAndroid Build Coastguard Worker ObjectState::kReportFailedPermanentError);
603*14675a02SAndroid Build Coastguard Worker return response;
604*14675a02SAndroid Build Coastguard Worker }
605*14675a02SAndroid Build Coastguard Worker
ReportNotCompleted(engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,std::optional<std::string> aggregation_session_Id)606*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::ReportNotCompleted(
607*14675a02SAndroid Build Coastguard Worker engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
608*14675a02SAndroid Build Coastguard Worker std::optional<std::string> aggregation_session_Id) {
609*14675a02SAndroid Build Coastguard Worker FCP_LOG(WARNING) << "Reporting outcome: " << static_cast<int>(phase_outcome);
610*14675a02SAndroid Build Coastguard Worker FCP_CHECK(object_state_ == ObjectState::kCheckinAccepted)
611*14675a02SAndroid Build Coastguard Worker << "Invalid call sequence";
612*14675a02SAndroid Build Coastguard Worker object_state_ = ObjectState::kReportCalled;
613*14675a02SAndroid Build Coastguard Worker ComputationResults results;
614*14675a02SAndroid Build Coastguard Worker results.emplace("tensorflow_checkpoint", "");
615*14675a02SAndroid Build Coastguard Worker auto response = Report(std::move(results), phase_outcome, plan_duration);
616*14675a02SAndroid Build Coastguard Worker // See note about how we handle 'permanent' errors at the top of this file.
617*14675a02SAndroid Build Coastguard Worker UpdateObjectStateIfPermanentError(response,
618*14675a02SAndroid Build Coastguard Worker ObjectState::kReportFailedPermanentError);
619*14675a02SAndroid Build Coastguard Worker return response;
620*14675a02SAndroid Build Coastguard Worker }
621*14675a02SAndroid Build Coastguard Worker
622*14675a02SAndroid Build Coastguard Worker class GrpcSecAggSendToServerImpl : public SecAggSendToServerBase {
623*14675a02SAndroid Build Coastguard Worker public:
GrpcSecAggSendToServerImpl(GrpcBidiStreamInterface * grpc_bidi_stream,const std::function<absl::Status (ClientToServerWrapperMessage *)> & report_func)624*14675a02SAndroid Build Coastguard Worker GrpcSecAggSendToServerImpl(
625*14675a02SAndroid Build Coastguard Worker GrpcBidiStreamInterface* grpc_bidi_stream,
626*14675a02SAndroid Build Coastguard Worker const std::function<absl::Status(ClientToServerWrapperMessage*)>&
627*14675a02SAndroid Build Coastguard Worker report_func)
628*14675a02SAndroid Build Coastguard Worker : grpc_bidi_stream_(grpc_bidi_stream), report_func_(report_func) {}
629*14675a02SAndroid Build Coastguard Worker ~GrpcSecAggSendToServerImpl() override = default;
630*14675a02SAndroid Build Coastguard Worker
Send(ClientToServerWrapperMessage * message)631*14675a02SAndroid Build Coastguard Worker void Send(ClientToServerWrapperMessage* message) override {
632*14675a02SAndroid Build Coastguard Worker // The commit message (MaskedInputRequest) must be piggy-backed onto the
633*14675a02SAndroid Build Coastguard Worker // ReportRequest message, the logic for which is encapsulated in
634*14675a02SAndroid Build Coastguard Worker // report_func_ so that it may be held in common between both accumulation
635*14675a02SAndroid Build Coastguard Worker // methods.
636*14675a02SAndroid Build Coastguard Worker if (message->message_content_case() ==
637*14675a02SAndroid Build Coastguard Worker ClientToServerWrapperMessage::MessageContentCase::
638*14675a02SAndroid Build Coastguard Worker kMaskedInputResponse) {
639*14675a02SAndroid Build Coastguard Worker auto status = report_func_(message);
640*14675a02SAndroid Build Coastguard Worker if (!status.ok())
641*14675a02SAndroid Build Coastguard Worker FCP_LOG(ERROR) << "Could not send ReportRequest: " << status;
642*14675a02SAndroid Build Coastguard Worker return;
643*14675a02SAndroid Build Coastguard Worker }
644*14675a02SAndroid Build Coastguard Worker ClientStreamMessage client_stream_message;
645*14675a02SAndroid Build Coastguard Worker client_stream_message.mutable_secure_aggregation_client_message()->Swap(
646*14675a02SAndroid Build Coastguard Worker message);
647*14675a02SAndroid Build Coastguard Worker auto bytes_to_upload = client_stream_message.ByteSizeLong();
648*14675a02SAndroid Build Coastguard Worker auto status = grpc_bidi_stream_->Send(&client_stream_message);
649*14675a02SAndroid Build Coastguard Worker if (status.ok()) {
650*14675a02SAndroid Build Coastguard Worker last_sent_message_size_ = bytes_to_upload;
651*14675a02SAndroid Build Coastguard Worker }
652*14675a02SAndroid Build Coastguard Worker }
653*14675a02SAndroid Build Coastguard Worker
654*14675a02SAndroid Build Coastguard Worker private:
655*14675a02SAndroid Build Coastguard Worker GrpcBidiStreamInterface* grpc_bidi_stream_;
656*14675a02SAndroid Build Coastguard Worker // SecAgg's output must be wrapped in a ReportRequest; because the report
657*14675a02SAndroid Build Coastguard Worker // logic is mostly generic, this lambda allows it to be shared between
658*14675a02SAndroid Build Coastguard Worker // aggregation types.
659*14675a02SAndroid Build Coastguard Worker const std::function<absl::Status(ClientToServerWrapperMessage*)>&
660*14675a02SAndroid Build Coastguard Worker report_func_;
661*14675a02SAndroid Build Coastguard Worker };
662*14675a02SAndroid Build Coastguard Worker
663*14675a02SAndroid Build Coastguard Worker class GrpcSecAggProtocolDelegate : public SecAggProtocolDelegate {
664*14675a02SAndroid Build Coastguard Worker public:
GrpcSecAggProtocolDelegate(absl::flat_hash_map<std::string,SideChannelExecutionInfo> side_channels,GrpcBidiStreamInterface * grpc_bidi_stream)665*14675a02SAndroid Build Coastguard Worker GrpcSecAggProtocolDelegate(
666*14675a02SAndroid Build Coastguard Worker absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels,
667*14675a02SAndroid Build Coastguard Worker GrpcBidiStreamInterface* grpc_bidi_stream)
668*14675a02SAndroid Build Coastguard Worker : side_channels_(std::move(side_channels)),
669*14675a02SAndroid Build Coastguard Worker grpc_bidi_stream_(grpc_bidi_stream) {}
670*14675a02SAndroid Build Coastguard Worker
GetModulus(const std::string & key)671*14675a02SAndroid Build Coastguard Worker absl::StatusOr<uint64_t> GetModulus(const std::string& key) override {
672*14675a02SAndroid Build Coastguard Worker auto execution_info = side_channels_.find(key);
673*14675a02SAndroid Build Coastguard Worker if (execution_info == side_channels_.end())
674*14675a02SAndroid Build Coastguard Worker return absl::InternalError(
675*14675a02SAndroid Build Coastguard Worker absl::StrCat("Execution not found for aggregand: ", key));
676*14675a02SAndroid Build Coastguard Worker uint64_t modulus;
677*14675a02SAndroid Build Coastguard Worker auto secure_aggregand = execution_info->second.secure_aggregand();
678*14675a02SAndroid Build Coastguard Worker // TODO(team): Delete output_bitwidth support once
679*14675a02SAndroid Build Coastguard Worker // modulus is fully rolled out.
680*14675a02SAndroid Build Coastguard Worker if (secure_aggregand.modulus() > 0) {
681*14675a02SAndroid Build Coastguard Worker modulus = secure_aggregand.modulus();
682*14675a02SAndroid Build Coastguard Worker } else {
683*14675a02SAndroid Build Coastguard Worker // Note: we ignore vector.get_bitwidth() here, because (1)
684*14675a02SAndroid Build Coastguard Worker // it is only an upper bound on the *input* bitwidth,
685*14675a02SAndroid Build Coastguard Worker // based on the Tensorflow dtype, but (2) we have exact
686*14675a02SAndroid Build Coastguard Worker // *output* bitwidth information from the execution_info,
687*14675a02SAndroid Build Coastguard Worker // and that is what SecAgg needs.
688*14675a02SAndroid Build Coastguard Worker modulus = 1ULL << secure_aggregand.output_bitwidth();
689*14675a02SAndroid Build Coastguard Worker }
690*14675a02SAndroid Build Coastguard Worker return modulus;
691*14675a02SAndroid Build Coastguard Worker }
692*14675a02SAndroid Build Coastguard Worker
ReceiveServerMessage()693*14675a02SAndroid Build Coastguard Worker absl::StatusOr<secagg::ServerToClientWrapperMessage> ReceiveServerMessage()
694*14675a02SAndroid Build Coastguard Worker override {
695*14675a02SAndroid Build Coastguard Worker ServerStreamMessage server_stream_message;
696*14675a02SAndroid Build Coastguard Worker absl::Status receive_status =
697*14675a02SAndroid Build Coastguard Worker grpc_bidi_stream_->Receive(&server_stream_message);
698*14675a02SAndroid Build Coastguard Worker if (!receive_status.ok()) {
699*14675a02SAndroid Build Coastguard Worker return absl::Status(receive_status.code(),
700*14675a02SAndroid Build Coastguard Worker absl::StrCat("Error during SecAgg receive: ",
701*14675a02SAndroid Build Coastguard Worker receive_status.message()));
702*14675a02SAndroid Build Coastguard Worker }
703*14675a02SAndroid Build Coastguard Worker last_received_message_size_ = server_stream_message.ByteSizeLong();
704*14675a02SAndroid Build Coastguard Worker if (!server_stream_message.has_secure_aggregation_server_message()) {
705*14675a02SAndroid Build Coastguard Worker return absl::InternalError(
706*14675a02SAndroid Build Coastguard Worker absl::StrCat("Bad response to SecAgg protocol; Expected "
707*14675a02SAndroid Build Coastguard Worker "ServerToClientWrapperMessage but got ",
708*14675a02SAndroid Build Coastguard Worker server_stream_message.kind_case(), "."));
709*14675a02SAndroid Build Coastguard Worker }
710*14675a02SAndroid Build Coastguard Worker return server_stream_message.secure_aggregation_server_message();
711*14675a02SAndroid Build Coastguard Worker }
712*14675a02SAndroid Build Coastguard Worker
Abort()713*14675a02SAndroid Build Coastguard Worker void Abort() override { grpc_bidi_stream_->Close(); }
last_received_message_size()714*14675a02SAndroid Build Coastguard Worker size_t last_received_message_size() override {
715*14675a02SAndroid Build Coastguard Worker return last_received_message_size_;
716*14675a02SAndroid Build Coastguard Worker };
717*14675a02SAndroid Build Coastguard Worker
718*14675a02SAndroid Build Coastguard Worker private:
719*14675a02SAndroid Build Coastguard Worker absl::flat_hash_map<std::string, SideChannelExecutionInfo> side_channels_;
720*14675a02SAndroid Build Coastguard Worker GrpcBidiStreamInterface* grpc_bidi_stream_;
721*14675a02SAndroid Build Coastguard Worker size_t last_received_message_size_;
722*14675a02SAndroid Build Coastguard Worker };
723*14675a02SAndroid Build Coastguard Worker
ReportInternal(std::string tf_checkpoint,engine::PhaseOutcome phase_outcome,absl::Duration plan_duration,ClientToServerWrapperMessage * secagg_commit_message)724*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::ReportInternal(
725*14675a02SAndroid Build Coastguard Worker std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
726*14675a02SAndroid Build Coastguard Worker absl::Duration plan_duration,
727*14675a02SAndroid Build Coastguard Worker ClientToServerWrapperMessage* secagg_commit_message) {
728*14675a02SAndroid Build Coastguard Worker ClientStreamMessage client_stream_message;
729*14675a02SAndroid Build Coastguard Worker auto report_request = client_stream_message.mutable_report_request();
730*14675a02SAndroid Build Coastguard Worker report_request->set_population_name(population_name_);
731*14675a02SAndroid Build Coastguard Worker report_request->set_execution_phase_id(execution_phase_id_);
732*14675a02SAndroid Build Coastguard Worker auto report = report_request->mutable_report();
733*14675a02SAndroid Build Coastguard Worker
734*14675a02SAndroid Build Coastguard Worker // 1. Include TF checkpoint and/or SecAgg commit message.
735*14675a02SAndroid Build Coastguard Worker report->set_update_checkpoint(std::move(tf_checkpoint));
736*14675a02SAndroid Build Coastguard Worker if (secagg_commit_message) {
737*14675a02SAndroid Build Coastguard Worker client_stream_message.mutable_secure_aggregation_client_message()->Swap(
738*14675a02SAndroid Build Coastguard Worker secagg_commit_message);
739*14675a02SAndroid Build Coastguard Worker }
740*14675a02SAndroid Build Coastguard Worker
741*14675a02SAndroid Build Coastguard Worker // 2. Include outcome of computation.
742*14675a02SAndroid Build Coastguard Worker report->set_status_code(phase_outcome == engine::COMPLETED
743*14675a02SAndroid Build Coastguard Worker ? google::rpc::OK
744*14675a02SAndroid Build Coastguard Worker : google::rpc::INTERNAL);
745*14675a02SAndroid Build Coastguard Worker
746*14675a02SAndroid Build Coastguard Worker // 3. Include client execution statistics, if any.
747*14675a02SAndroid Build Coastguard Worker ClientExecutionStats client_execution_stats;
748*14675a02SAndroid Build Coastguard Worker client_execution_stats.mutable_duration()->set_seconds(
749*14675a02SAndroid Build Coastguard Worker absl::IDivDuration(plan_duration, absl::Seconds(1), &plan_duration));
750*14675a02SAndroid Build Coastguard Worker client_execution_stats.mutable_duration()->set_nanos(static_cast<int32_t>(
751*14675a02SAndroid Build Coastguard Worker absl::IDivDuration(plan_duration, absl::Nanoseconds(1), &plan_duration)));
752*14675a02SAndroid Build Coastguard Worker report->add_serialized_train_event()->PackFrom(client_execution_stats);
753*14675a02SAndroid Build Coastguard Worker
754*14675a02SAndroid Build Coastguard Worker // 4. Send ReportRequest.
755*14675a02SAndroid Build Coastguard Worker
756*14675a02SAndroid Build Coastguard Worker // Note that we do not use the GrpcFederatedProtocol::Send(...) helper method
757*14675a02SAndroid Build Coastguard Worker // here, since we are already running within a call to
758*14675a02SAndroid Build Coastguard Worker // InterruptibleRunner::Run.
759*14675a02SAndroid Build Coastguard Worker const auto status = this->grpc_bidi_stream_->Send(&client_stream_message);
760*14675a02SAndroid Build Coastguard Worker if (!status.ok()) {
761*14675a02SAndroid Build Coastguard Worker return absl::Status(
762*14675a02SAndroid Build Coastguard Worker status.code(),
763*14675a02SAndroid Build Coastguard Worker absl::StrCat("Error sending ReportRequest: ", status.message()));
764*14675a02SAndroid Build Coastguard Worker }
765*14675a02SAndroid Build Coastguard Worker
766*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
767*14675a02SAndroid Build Coastguard Worker }
768*14675a02SAndroid Build Coastguard Worker
Report(ComputationResults results,engine::PhaseOutcome phase_outcome,absl::Duration plan_duration)769*14675a02SAndroid Build Coastguard Worker absl::Status GrpcFederatedProtocol::Report(ComputationResults results,
770*14675a02SAndroid Build Coastguard Worker engine::PhaseOutcome phase_outcome,
771*14675a02SAndroid Build Coastguard Worker absl::Duration plan_duration) {
772*14675a02SAndroid Build Coastguard Worker std::string tf_checkpoint;
773*14675a02SAndroid Build Coastguard Worker bool has_checkpoint;
774*14675a02SAndroid Build Coastguard Worker for (auto& [k, v] : results) {
775*14675a02SAndroid Build Coastguard Worker if (std::holds_alternative<TFCheckpoint>(v)) {
776*14675a02SAndroid Build Coastguard Worker tf_checkpoint = std::get<TFCheckpoint>(std::move(v));
777*14675a02SAndroid Build Coastguard Worker has_checkpoint = true;
778*14675a02SAndroid Build Coastguard Worker break;
779*14675a02SAndroid Build Coastguard Worker }
780*14675a02SAndroid Build Coastguard Worker }
781*14675a02SAndroid Build Coastguard Worker
782*14675a02SAndroid Build Coastguard Worker // This lambda allows for convenient reporting from within SecAgg's
783*14675a02SAndroid Build Coastguard Worker // SendToServerInterface::Send().
784*14675a02SAndroid Build Coastguard Worker std::function<absl::Status(ClientToServerWrapperMessage*)> report_lambda =
785*14675a02SAndroid Build Coastguard Worker [&](ClientToServerWrapperMessage* secagg_commit_message) -> absl::Status {
786*14675a02SAndroid Build Coastguard Worker return ReportInternal(std::move(tf_checkpoint), phase_outcome,
787*14675a02SAndroid Build Coastguard Worker plan_duration, secagg_commit_message);
788*14675a02SAndroid Build Coastguard Worker };
789*14675a02SAndroid Build Coastguard Worker
790*14675a02SAndroid Build Coastguard Worker // Run the Secure Aggregation protocol, if necessary.
791*14675a02SAndroid Build Coastguard Worker if (side_channel_protocol_execution_info_.has_secure_aggregation()) {
792*14675a02SAndroid Build Coastguard Worker auto secure_aggregation_protocol_execution_info =
793*14675a02SAndroid Build Coastguard Worker side_channel_protocol_execution_info_.secure_aggregation();
794*14675a02SAndroid Build Coastguard Worker auto expected_number_of_clients =
795*14675a02SAndroid Build Coastguard Worker secure_aggregation_protocol_execution_info.expected_number_of_clients();
796*14675a02SAndroid Build Coastguard Worker
797*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Reporting via Secure Aggregation";
798*14675a02SAndroid Build Coastguard Worker if (phase_outcome != engine::COMPLETED)
799*14675a02SAndroid Build Coastguard Worker return absl::InternalError(
800*14675a02SAndroid Build Coastguard Worker "Aborting the SecAgg protocol (no update was produced).");
801*14675a02SAndroid Build Coastguard Worker
802*14675a02SAndroid Build Coastguard Worker if (side_channel_protocol_options_response_.secure_aggregation()
803*14675a02SAndroid Build Coastguard Worker .client_variant() != secagg::SECAGG_CLIENT_VARIANT_NATIVE_V1) {
804*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
805*14675a02SAndroid Build Coastguard Worker ProdDiagCode::SECAGG_CLIENT_ERROR_UNSUPPORTED_VERSION);
806*14675a02SAndroid Build Coastguard Worker return absl::InternalError(absl::StrCat(
807*14675a02SAndroid Build Coastguard Worker "Unsupported SecAgg client variant: ",
808*14675a02SAndroid Build Coastguard Worker side_channel_protocol_options_response_.secure_aggregation()
809*14675a02SAndroid Build Coastguard Worker .client_variant()));
810*14675a02SAndroid Build Coastguard Worker }
811*14675a02SAndroid Build Coastguard Worker
812*14675a02SAndroid Build Coastguard Worker auto send_to_server_impl = std::make_unique<GrpcSecAggSendToServerImpl>(
813*14675a02SAndroid Build Coastguard Worker grpc_bidi_stream_.get(), report_lambda);
814*14675a02SAndroid Build Coastguard Worker auto secagg_event_publisher = event_publisher_->secagg_event_publisher();
815*14675a02SAndroid Build Coastguard Worker FCP_CHECK(secagg_event_publisher)
816*14675a02SAndroid Build Coastguard Worker << "An implementation of "
817*14675a02SAndroid Build Coastguard Worker << "SecAggEventPublisher must be provided.";
818*14675a02SAndroid Build Coastguard Worker auto delegate = std::make_unique<GrpcSecAggProtocolDelegate>(
819*14675a02SAndroid Build Coastguard Worker side_channels_, grpc_bidi_stream_.get());
820*14675a02SAndroid Build Coastguard Worker std::unique_ptr<SecAggRunner> secagg_runner =
821*14675a02SAndroid Build Coastguard Worker secagg_runner_factory_->CreateSecAggRunner(
822*14675a02SAndroid Build Coastguard Worker std::move(send_to_server_impl), std::move(delegate),
823*14675a02SAndroid Build Coastguard Worker secagg_event_publisher, log_manager_, interruptible_runner_.get(),
824*14675a02SAndroid Build Coastguard Worker expected_number_of_clients,
825*14675a02SAndroid Build Coastguard Worker secure_aggregation_protocol_execution_info
826*14675a02SAndroid Build Coastguard Worker .minimum_surviving_clients_for_reconstruction());
827*14675a02SAndroid Build Coastguard Worker
828*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(secagg_runner->Run(std::move(results)));
829*14675a02SAndroid Build Coastguard Worker } else {
830*14675a02SAndroid Build Coastguard Worker // Report without secure aggregation.
831*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Reporting via Simple Aggregation";
832*14675a02SAndroid Build Coastguard Worker if (results.size() != 1 || !has_checkpoint) {
833*14675a02SAndroid Build Coastguard Worker return absl::InternalError(
834*14675a02SAndroid Build Coastguard Worker "Simple Aggregation aggregands have unexpected format.");
835*14675a02SAndroid Build Coastguard Worker }
836*14675a02SAndroid Build Coastguard Worker FCP_RETURN_IF_ERROR(interruptible_runner_->Run(
837*14675a02SAndroid Build Coastguard Worker [&report_lambda]() { return report_lambda(nullptr); },
838*14675a02SAndroid Build Coastguard Worker [this]() {
839*14675a02SAndroid Build Coastguard Worker // What about event_publisher_ and log_manager_?
840*14675a02SAndroid Build Coastguard Worker this->grpc_bidi_stream_->Close();
841*14675a02SAndroid Build Coastguard Worker }));
842*14675a02SAndroid Build Coastguard Worker }
843*14675a02SAndroid Build Coastguard Worker
844*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Finished reporting.";
845*14675a02SAndroid Build Coastguard Worker
846*14675a02SAndroid Build Coastguard Worker // Receive ReportResponse.
847*14675a02SAndroid Build Coastguard Worker ServerStreamMessage server_stream_message;
848*14675a02SAndroid Build Coastguard Worker absl::Status receive_status = Receive(&server_stream_message);
849*14675a02SAndroid Build Coastguard Worker if (receive_status.code() == absl::StatusCode::kAborted) {
850*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Server responded ABORTED.";
851*14675a02SAndroid Build Coastguard Worker } else if (receive_status.code() == absl::StatusCode::kCancelled) {
852*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Upload was cancelled by the client.";
853*14675a02SAndroid Build Coastguard Worker }
854*14675a02SAndroid Build Coastguard Worker if (!receive_status.ok()) {
855*14675a02SAndroid Build Coastguard Worker return absl::Status(
856*14675a02SAndroid Build Coastguard Worker receive_status.code(),
857*14675a02SAndroid Build Coastguard Worker absl::StrCat("Error after ReportRequest: ", receive_status.message()));
858*14675a02SAndroid Build Coastguard Worker }
859*14675a02SAndroid Build Coastguard Worker if (!server_stream_message.has_report_response()) {
860*14675a02SAndroid Build Coastguard Worker return absl::UnimplementedError(absl::StrCat(
861*14675a02SAndroid Build Coastguard Worker "Bad response to ReportRequest; Expected REPORT_RESPONSE but got ",
862*14675a02SAndroid Build Coastguard Worker server_stream_message.kind_case(), "."));
863*14675a02SAndroid Build Coastguard Worker }
864*14675a02SAndroid Build Coastguard Worker return absl::OkStatus();
865*14675a02SAndroid Build Coastguard Worker }
866*14675a02SAndroid Build Coastguard Worker
GetLatestRetryWindow()867*14675a02SAndroid Build Coastguard Worker RetryWindow GrpcFederatedProtocol::GetLatestRetryWindow() {
868*14675a02SAndroid Build Coastguard Worker // We explicitly enumerate all possible states here rather than using
869*14675a02SAndroid Build Coastguard Worker // "default", to ensure that when new states are added later on, the author
870*14675a02SAndroid Build Coastguard Worker // is forced to update this method and consider which is the correct
871*14675a02SAndroid Build Coastguard Worker // RetryWindow to return.
872*14675a02SAndroid Build Coastguard Worker switch (object_state_) {
873*14675a02SAndroid Build Coastguard Worker case ObjectState::kCheckinAccepted:
874*14675a02SAndroid Build Coastguard Worker case ObjectState::kReportCalled:
875*14675a02SAndroid Build Coastguard Worker // If a client makes it past the 'checkin acceptance' stage, we use the
876*14675a02SAndroid Build Coastguard Worker // 'accepted' RetryWindow unconditionally (unless a permanent error is
877*14675a02SAndroid Build Coastguard Worker // encountered). This includes cases where the checkin is accepted, but
878*14675a02SAndroid Build Coastguard Worker // the report request results in a (transient) error.
879*14675a02SAndroid Build Coastguard Worker FCP_CHECK(checkin_request_ack_info_.has_value());
880*14675a02SAndroid Build Coastguard Worker return GenerateRetryWindowFromRetryTimeAndToken(
881*14675a02SAndroid Build Coastguard Worker checkin_request_ack_info_->retry_info_if_accepted);
882*14675a02SAndroid Build Coastguard Worker case ObjectState::kEligibilityEvalCheckinRejected:
883*14675a02SAndroid Build Coastguard Worker case ObjectState::kEligibilityEvalDisabled:
884*14675a02SAndroid Build Coastguard Worker case ObjectState::kEligibilityEvalEnabled:
885*14675a02SAndroid Build Coastguard Worker case ObjectState::kCheckinRejected:
886*14675a02SAndroid Build Coastguard Worker FCP_CHECK(checkin_request_ack_info_.has_value());
887*14675a02SAndroid Build Coastguard Worker return GenerateRetryWindowFromRetryTimeAndToken(
888*14675a02SAndroid Build Coastguard Worker checkin_request_ack_info_->retry_info_if_rejected);
889*14675a02SAndroid Build Coastguard Worker case ObjectState::kInitialized:
890*14675a02SAndroid Build Coastguard Worker case ObjectState::kEligibilityEvalCheckinFailed:
891*14675a02SAndroid Build Coastguard Worker case ObjectState::kCheckinFailed:
892*14675a02SAndroid Build Coastguard Worker // If the flag is true, then we use the previously chosen absolute retry
893*14675a02SAndroid Build Coastguard Worker // time instead (if available).
894*14675a02SAndroid Build Coastguard Worker if (checkin_request_ack_info_.has_value()) {
895*14675a02SAndroid Build Coastguard Worker // If we already received a server-provided retry window, then use it.
896*14675a02SAndroid Build Coastguard Worker return GenerateRetryWindowFromRetryTimeAndToken(
897*14675a02SAndroid Build Coastguard Worker checkin_request_ack_info_->retry_info_if_rejected);
898*14675a02SAndroid Build Coastguard Worker }
899*14675a02SAndroid Build Coastguard Worker // Otherwise, we generate a retry window using the flag-provided transient
900*14675a02SAndroid Build Coastguard Worker // error retry period.
901*14675a02SAndroid Build Coastguard Worker return GenerateRetryWindowFromTargetDelay(
902*14675a02SAndroid Build Coastguard Worker absl::Seconds(
903*14675a02SAndroid Build Coastguard Worker flags_->federated_training_transient_errors_retry_delay_secs()),
904*14675a02SAndroid Build Coastguard Worker // NOLINTBEGIN(whitespace/line_length)
905*14675a02SAndroid Build Coastguard Worker flags_
906*14675a02SAndroid Build Coastguard Worker ->federated_training_transient_errors_retry_delay_jitter_percent(),
907*14675a02SAndroid Build Coastguard Worker // NOLINTEND
908*14675a02SAndroid Build Coastguard Worker bit_gen_);
909*14675a02SAndroid Build Coastguard Worker case ObjectState::kEligibilityEvalCheckinFailedPermanentError:
910*14675a02SAndroid Build Coastguard Worker case ObjectState::kCheckinFailedPermanentError:
911*14675a02SAndroid Build Coastguard Worker case ObjectState::kReportFailedPermanentError:
912*14675a02SAndroid Build Coastguard Worker // If we encountered a permanent error during the eligibility eval or
913*14675a02SAndroid Build Coastguard Worker // regular checkins, then we use the Flags-configured 'permanent error'
914*14675a02SAndroid Build Coastguard Worker // retry period. Note that we do so regardless of whether the server had,
915*14675a02SAndroid Build Coastguard Worker // by the time the permanent error was received, already returned a
916*14675a02SAndroid Build Coastguard Worker // CheckinRequestAck containing a set of retry windows. See note on error
917*14675a02SAndroid Build Coastguard Worker // handling at the top of this file.
918*14675a02SAndroid Build Coastguard Worker return GenerateRetryWindowFromTargetDelay(
919*14675a02SAndroid Build Coastguard Worker absl::Seconds(
920*14675a02SAndroid Build Coastguard Worker flags_->federated_training_permanent_errors_retry_delay_secs()),
921*14675a02SAndroid Build Coastguard Worker // NOLINTBEGIN(whitespace/line_length)
922*14675a02SAndroid Build Coastguard Worker flags_
923*14675a02SAndroid Build Coastguard Worker ->federated_training_permanent_errors_retry_delay_jitter_percent(),
924*14675a02SAndroid Build Coastguard Worker // NOLINTEND
925*14675a02SAndroid Build Coastguard Worker bit_gen_);
926*14675a02SAndroid Build Coastguard Worker case ObjectState::kMultipleTaskAssignmentsAccepted:
927*14675a02SAndroid Build Coastguard Worker case ObjectState::kMultipleTaskAssignmentsFailed:
928*14675a02SAndroid Build Coastguard Worker case ObjectState::kMultipleTaskAssignmentsFailedPermanentError:
929*14675a02SAndroid Build Coastguard Worker case ObjectState::kMultipleTaskAssignmentsNoAvailableTask:
930*14675a02SAndroid Build Coastguard Worker case ObjectState::kReportMultipleTaskPartialError:
931*14675a02SAndroid Build Coastguard Worker FCP_LOG(FATAL) << "Multi-task assignments is not supported by gRPC.";
932*14675a02SAndroid Build Coastguard Worker RetryWindow retry_window;
933*14675a02SAndroid Build Coastguard Worker return retry_window;
934*14675a02SAndroid Build Coastguard Worker }
935*14675a02SAndroid Build Coastguard Worker }
936*14675a02SAndroid Build Coastguard Worker
937*14675a02SAndroid Build Coastguard Worker // Converts the given RetryTimeAndToken to a zero-width RetryWindow (where
938*14675a02SAndroid Build Coastguard Worker // delay_min and delay_max are set to the same value), by converting the target
939*14675a02SAndroid Build Coastguard Worker // retry time to a delay relative to the current timestamp.
GenerateRetryWindowFromRetryTimeAndToken(const GrpcFederatedProtocol::RetryTimeAndToken & retry_info)940*14675a02SAndroid Build Coastguard Worker RetryWindow GrpcFederatedProtocol::GenerateRetryWindowFromRetryTimeAndToken(
941*14675a02SAndroid Build Coastguard Worker const GrpcFederatedProtocol::RetryTimeAndToken& retry_info) {
942*14675a02SAndroid Build Coastguard Worker // Generate a RetryWindow with delay_min and delay_max both set to the same
943*14675a02SAndroid Build Coastguard Worker // value.
944*14675a02SAndroid Build Coastguard Worker RetryWindow retry_window =
945*14675a02SAndroid Build Coastguard Worker GenerateRetryWindowFromRetryTime(retry_info.retry_time);
946*14675a02SAndroid Build Coastguard Worker retry_window.set_retry_token(retry_info.retry_token);
947*14675a02SAndroid Build Coastguard Worker return retry_window;
948*14675a02SAndroid Build Coastguard Worker }
949*14675a02SAndroid Build Coastguard Worker
UpdateObjectStateIfPermanentError(absl::Status status,GrpcFederatedProtocol::ObjectState permanent_error_object_state)950*14675a02SAndroid Build Coastguard Worker void GrpcFederatedProtocol::UpdateObjectStateIfPermanentError(
951*14675a02SAndroid Build Coastguard Worker absl::Status status,
952*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::ObjectState permanent_error_object_state) {
953*14675a02SAndroid Build Coastguard Worker if (federated_training_permanent_error_codes_.contains(
954*14675a02SAndroid Build Coastguard Worker static_cast<int32_t>(status.code()))) {
955*14675a02SAndroid Build Coastguard Worker object_state_ = permanent_error_object_state;
956*14675a02SAndroid Build Coastguard Worker }
957*14675a02SAndroid Build Coastguard Worker }
958*14675a02SAndroid Build Coastguard Worker
959*14675a02SAndroid Build Coastguard Worker absl::StatusOr<FederatedProtocol::PlanAndCheckpointPayloads>
FetchTaskResources(GrpcFederatedProtocol::TaskResources task_resources)960*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::FetchTaskResources(
961*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::TaskResources task_resources) {
962*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(UriOrInlineData plan_uri_or_data,
963*14675a02SAndroid Build Coastguard Worker ConvertResourceToUriOrInlineData(task_resources.plan));
964*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(
965*14675a02SAndroid Build Coastguard Worker UriOrInlineData checkpoint_uri_or_data,
966*14675a02SAndroid Build Coastguard Worker ConvertResourceToUriOrInlineData(task_resources.checkpoint));
967*14675a02SAndroid Build Coastguard Worker
968*14675a02SAndroid Build Coastguard Worker // Log a diag code if either resource is about to be downloaded via HTTP.
969*14675a02SAndroid Build Coastguard Worker if (!plan_uri_or_data.uri().uri.empty() ||
970*14675a02SAndroid Build Coastguard Worker !checkpoint_uri_or_data.uri().uri.empty()) {
971*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
972*14675a02SAndroid Build Coastguard Worker ProdDiagCode::HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_USES_HTTP);
973*14675a02SAndroid Build Coastguard Worker }
974*14675a02SAndroid Build Coastguard Worker
975*14675a02SAndroid Build Coastguard Worker // Fetch the plan and init checkpoint resources if they need to be fetched
976*14675a02SAndroid Build Coastguard Worker // (using the inline data instead if available).
977*14675a02SAndroid Build Coastguard Worker absl::StatusOr<
978*14675a02SAndroid Build Coastguard Worker std::vector<absl::StatusOr<::fcp::client::http::InMemoryHttpResponse>>>
979*14675a02SAndroid Build Coastguard Worker resource_responses;
980*14675a02SAndroid Build Coastguard Worker {
981*14675a02SAndroid Build Coastguard Worker auto started_stopwatch = network_stopwatch_->Start();
982*14675a02SAndroid Build Coastguard Worker resource_responses = ::fcp::client::http::FetchResourcesInMemory(
983*14675a02SAndroid Build Coastguard Worker *http_client_, *interruptible_runner_,
984*14675a02SAndroid Build Coastguard Worker {plan_uri_or_data, checkpoint_uri_or_data}, &http_bytes_downloaded_,
985*14675a02SAndroid Build Coastguard Worker &http_bytes_uploaded_, resource_cache_);
986*14675a02SAndroid Build Coastguard Worker }
987*14675a02SAndroid Build Coastguard Worker if (!resource_responses.ok()) {
988*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
989*14675a02SAndroid Build Coastguard Worker ProdDiagCode::
990*14675a02SAndroid Build Coastguard Worker HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
991*14675a02SAndroid Build Coastguard Worker return resource_responses.status();
992*14675a02SAndroid Build Coastguard Worker }
993*14675a02SAndroid Build Coastguard Worker auto& plan_data_response = (*resource_responses)[0];
994*14675a02SAndroid Build Coastguard Worker auto& checkpoint_data_response = (*resource_responses)[1];
995*14675a02SAndroid Build Coastguard Worker
996*14675a02SAndroid Build Coastguard Worker if (!plan_data_response.ok() || !checkpoint_data_response.ok()) {
997*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
998*14675a02SAndroid Build Coastguard Worker ProdDiagCode::
999*14675a02SAndroid Build Coastguard Worker HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_FAILED);
1000*14675a02SAndroid Build Coastguard Worker }
1001*14675a02SAndroid Build Coastguard Worker // Note: we forward any error during the fetching of the plan/checkpoint
1002*14675a02SAndroid Build Coastguard Worker // resources resources to the caller, which means that these error codes
1003*14675a02SAndroid Build Coastguard Worker // will be checked against the set of 'permanent' error codes, just like the
1004*14675a02SAndroid Build Coastguard Worker // errors in response to the protocol request are.
1005*14675a02SAndroid Build Coastguard Worker if (!plan_data_response.ok()) {
1006*14675a02SAndroid Build Coastguard Worker return absl::Status(plan_data_response.status().code(),
1007*14675a02SAndroid Build Coastguard Worker absl::StrCat("plan fetch failed: ",
1008*14675a02SAndroid Build Coastguard Worker plan_data_response.status().ToString()));
1009*14675a02SAndroid Build Coastguard Worker }
1010*14675a02SAndroid Build Coastguard Worker if (!checkpoint_data_response.ok()) {
1011*14675a02SAndroid Build Coastguard Worker return absl::Status(
1012*14675a02SAndroid Build Coastguard Worker checkpoint_data_response.status().code(),
1013*14675a02SAndroid Build Coastguard Worker absl::StrCat("checkpoint fetch failed: ",
1014*14675a02SAndroid Build Coastguard Worker checkpoint_data_response.status().ToString()));
1015*14675a02SAndroid Build Coastguard Worker }
1016*14675a02SAndroid Build Coastguard Worker if (!plan_uri_or_data.uri().uri.empty() ||
1017*14675a02SAndroid Build Coastguard Worker !checkpoint_uri_or_data.uri().uri.empty()) {
1018*14675a02SAndroid Build Coastguard Worker // We only want to log this diag code when we actually did fetch something
1019*14675a02SAndroid Build Coastguard Worker // via HTTP.
1020*14675a02SAndroid Build Coastguard Worker log_manager_->LogDiag(
1021*14675a02SAndroid Build Coastguard Worker ProdDiagCode::
1022*14675a02SAndroid Build Coastguard Worker HTTP_GRPC_PROTOCOL_REGULAR_TASK_RESOURCE_HTTP_FETCH_SUCCEEDED);
1023*14675a02SAndroid Build Coastguard Worker }
1024*14675a02SAndroid Build Coastguard Worker
1025*14675a02SAndroid Build Coastguard Worker return PlanAndCheckpointPayloads{plan_data_response->body,
1026*14675a02SAndroid Build Coastguard Worker checkpoint_data_response->body};
1027*14675a02SAndroid Build Coastguard Worker }
1028*14675a02SAndroid Build Coastguard Worker
1029*14675a02SAndroid Build Coastguard Worker // Convert a Resource proto into a UriOrInlineData object. Returns an
1030*14675a02SAndroid Build Coastguard Worker // `INVALID_ARGUMENT` error if the given `Resource` has the `uri` field set to
1031*14675a02SAndroid Build Coastguard Worker // an empty value, or an `UNIMPLEMENTED` error if the `Resource` has an unknown
1032*14675a02SAndroid Build Coastguard Worker // field set.
1033*14675a02SAndroid Build Coastguard Worker absl::StatusOr<UriOrInlineData>
ConvertResourceToUriOrInlineData(const GrpcFederatedProtocol::TaskResource & resource)1034*14675a02SAndroid Build Coastguard Worker GrpcFederatedProtocol::ConvertResourceToUriOrInlineData(
1035*14675a02SAndroid Build Coastguard Worker const GrpcFederatedProtocol::TaskResource& resource) {
1036*14675a02SAndroid Build Coastguard Worker // We need to support 3 states:
1037*14675a02SAndroid Build Coastguard Worker // - Inline data is available.
1038*14675a02SAndroid Build Coastguard Worker // - No inline data nor is there a URI. This should be treated as there being
1039*14675a02SAndroid Build Coastguard Worker // an 'empty' inline data.
1040*14675a02SAndroid Build Coastguard Worker // - No inline data is available but a URI is available.
1041*14675a02SAndroid Build Coastguard Worker if (!resource.has_uri) {
1042*14675a02SAndroid Build Coastguard Worker // If the URI field wasn't set, then we'll just use the inline data field
1043*14675a02SAndroid Build Coastguard Worker // (which will either be set or be empty).
1044*14675a02SAndroid Build Coastguard Worker //
1045*14675a02SAndroid Build Coastguard Worker // Note: this copies the data into the new absl::Cord. However, this Cord is
1046*14675a02SAndroid Build Coastguard Worker // then passed around all the way to fl_runner.cc without copying its data,
1047*14675a02SAndroid Build Coastguard Worker // so this is ultimately approx. as efficient as the non-HTTP resource code
1048*14675a02SAndroid Build Coastguard Worker // path where we also make a copy of the protobuf string into a new string
1049*14675a02SAndroid Build Coastguard Worker // which is then returned.
1050*14675a02SAndroid Build Coastguard Worker return UriOrInlineData::CreateInlineData(
1051*14675a02SAndroid Build Coastguard Worker absl::Cord(resource.data),
1052*14675a02SAndroid Build Coastguard Worker UriOrInlineData::InlineData::CompressionFormat::kUncompressed);
1053*14675a02SAndroid Build Coastguard Worker }
1054*14675a02SAndroid Build Coastguard Worker if (resource.uri.empty()) {
1055*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError(
1056*14675a02SAndroid Build Coastguard Worker "Resource uri must be non-empty when set");
1057*14675a02SAndroid Build Coastguard Worker }
1058*14675a02SAndroid Build Coastguard Worker return UriOrInlineData::CreateUri(resource.uri, resource.client_cache_id,
1059*14675a02SAndroid Build Coastguard Worker resource.max_age);
1060*14675a02SAndroid Build Coastguard Worker }
1061*14675a02SAndroid Build Coastguard Worker
GetNetworkStats()1062*14675a02SAndroid Build Coastguard Worker NetworkStats GrpcFederatedProtocol::GetNetworkStats() {
1063*14675a02SAndroid Build Coastguard Worker // Note: the `HttpClient` bandwidth stats are similar to the gRPC protocol's
1064*14675a02SAndroid Build Coastguard Worker // "chunking layer" stats, in that they reflect as closely as possible the
1065*14675a02SAndroid Build Coastguard Worker // amount of data sent on the wire.
1066*14675a02SAndroid Build Coastguard Worker return {.bytes_downloaded = grpc_bidi_stream_->ChunkingLayerBytesReceived() +
1067*14675a02SAndroid Build Coastguard Worker http_bytes_downloaded_,
1068*14675a02SAndroid Build Coastguard Worker .bytes_uploaded = grpc_bidi_stream_->ChunkingLayerBytesSent() +
1069*14675a02SAndroid Build Coastguard Worker http_bytes_uploaded_,
1070*14675a02SAndroid Build Coastguard Worker .network_duration = network_stopwatch_->GetTotalDuration()};
1071*14675a02SAndroid Build Coastguard Worker }
1072*14675a02SAndroid Build Coastguard Worker
1073*14675a02SAndroid Build Coastguard Worker } // namespace client
1074*14675a02SAndroid Build Coastguard Worker } // namespace fcp
1075