xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_federated_protocol.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
17 #define FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
18 
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <optional>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "absl/container/node_hash_map.h"
31 #include "absl/random/random.h"
32 #include "absl/status/status.h"
33 #include "absl/status/statusor.h"
34 #include "absl/time/time.h"
35 #include "fcp/base/monitoring.h"
36 #include "fcp/base/wall_clock_stopwatch.h"
37 #include "fcp/client/cache/resource_cache.h"
38 #include "fcp/client/engine/engine.pb.h"
39 #include "fcp/client/event_publisher.h"
40 #include "fcp/client/federated_protocol.h"
41 #include "fcp/client/fl_runner.pb.h"
42 #include "fcp/client/flags.h"
43 #include "fcp/client/grpc_bidi_stream.h"
44 #include "fcp/client/http/http_client.h"
45 #include "fcp/client/http/in_memory_request_response.h"
46 #include "fcp/client/interruptible_runner.h"
47 #include "fcp/client/log_manager.h"
48 #include "fcp/client/secagg_runner.h"
49 #include "fcp/client/selector_context.pb.h"
50 #include "fcp/client/stats.h"
51 #include "fcp/protocol/grpc_chunked_bidi_stream.h"
52 #include "fcp/protos/federated_api.pb.h"
53 #include "fcp/protos/plan.pb.h"
54 #include "fcp/secagg/client/secagg_client.h"
55 
56 namespace fcp {
57 namespace client {
58 
59 // Implements a single session of the gRPC-based Federated Learning protocol.
60 class GrpcFederatedProtocol : public ::fcp::client::FederatedProtocol {
61  public:
62   GrpcFederatedProtocol(
63       EventPublisher* event_publisher, LogManager* log_manager,
64       std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
65       const Flags* flags, ::fcp::client::http::HttpClient* http_client,
66       const std::string& federated_service_uri, const std::string& api_key,
67       const std::string& test_cert_path, absl::string_view population_name,
68       absl::string_view retry_token, absl::string_view client_version,
69       absl::string_view attestation_measurement,
70       std::function<bool()> should_abort,
71       const InterruptibleRunner::TimingConfig& timing_config,
72       const int64_t grpc_channel_deadline_seconds,
73       cache::ResourceCache* resource_cache);
74 
75   // Test c'tor.
76   GrpcFederatedProtocol(
77       EventPublisher* event_publisher, LogManager* log_manager,
78       std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
79       const Flags* flags, ::fcp::client::http::HttpClient* http_client,
80       std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
81       absl::string_view population_name, absl::string_view retry_token,
82       absl::string_view client_version,
83       absl::string_view attestation_measurement,
84       std::function<bool()> should_abort, absl::BitGen bit_gen,
85       const InterruptibleRunner::TimingConfig& timing_config,
86       cache::ResourceCache* resource_cache);
87 
88   ~GrpcFederatedProtocol() override;
89 
90   absl::StatusOr<::fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
91   EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)>
92                              payload_uris_received_callback) override;
93 
94   void ReportEligibilityEvalError(absl::Status error_status) override;
95 
96   absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult> Checkin(
97       const std::optional<
98           google::internal::federatedml::v2::TaskEligibilityInfo>&
99           task_eligibility_info,
100       std::function<void(const TaskAssignment&)> payload_uris_received_callback)
101       override;
102 
103   absl::StatusOr<::fcp::client::FederatedProtocol::MultipleTaskAssignments>
104   PerformMultipleTaskAssignments(
105       const std::vector<std::string>& task_names) override;
106 
107   absl::Status ReportCompleted(
108       ComputationResults results, absl::Duration plan_duration,
109       std::optional<std::string> aggregation_session_id) override;
110 
111   absl::Status ReportNotCompleted(
112       engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
113       std::optional<std::string> aggregation_session_id) override;
114 
115   google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
116       override;
117 
118   NetworkStats GetNetworkStats() override;
119 
120  private:
121   // Internal implementation of reporting for use by ReportCompleted() and
122   // ReportNotCompleted().
123   absl::Status Report(ComputationResults results,
124                       engine::PhaseOutcome phase_outcome,
125                       absl::Duration plan_duration);
126   absl::Status ReportInternal(
127       std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
128       absl::Duration plan_duration,
129       fcp::secagg::ClientToServerWrapperMessage* secagg_commit_message);
130 
131   // Helper function to send a ClientStreamMessage. If sending did not succeed,
132   // closes the underlying grpc stream. If sending does succeed then it updates
133   // `bytes_uploaded_`.
134   absl::Status Send(google::internal::federatedml::v2::ClientStreamMessage*
135                         client_stream_message);
136 
137   // Helper function to receive a ServerStreamMessage. If receiving did not
138   // succeed, closes the underlying grpc stream. If receiving does succeed then
139   // it updates `bytes_downloaded_`.
140   absl::Status Receive(google::internal::federatedml::v2::ServerStreamMessage*
141                            server_stream_message);
142 
143   // Helper function to compose a ProtocolOptionsRequest for eligibility eval or
144   // regular checkin requests.
145   google::internal::federatedml::v2::ProtocolOptionsRequest
146   CreateProtocolOptionsRequest(bool should_ack_checkin) const;
147 
148   // Helper function to compose and send an EligibilityEvalCheckinRequest to the
149   // server.
150   absl::Status SendEligibilityEvalCheckinRequest();
151 
152   // Helper function to compose and send a CheckinRequest to the server.
153   absl::Status SendCheckinRequest(
154       const std::optional<
155           google::internal::federatedml::v2::TaskEligibilityInfo>&
156           task_eligibility_info);
157 
158   // Helper to receive + process a CheckinRequestAck message.
159   absl::Status ReceiveCheckinRequestAck();
160 
161   // Helper to receive + process an EligibilityEvalCheckinResponse message.
162   absl::StatusOr<EligibilityEvalCheckinResult>
163   ReceiveEligibilityEvalCheckinResponse(
164       absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
165                                  payload_uris_received_callback);
166 
167   // Helper to receive + process a CheckinResponse message.
168   absl::StatusOr<CheckinResult> ReceiveCheckinResponse(
169       absl::Time start_time, std::function<void(const TaskAssignment&)>
170                                  payload_uris_received_callback);
171 
172   // Utility class for holding an absolute retry time and a corresponding retry
173   // token.
174   struct RetryTimeAndToken {
175     absl::Time retry_time;
176     std::string retry_token;
177   };
178   // Helper to generate a RetryWindow from a given time and token.
179   google::internal::federatedml::v2::RetryWindow
180   GenerateRetryWindowFromRetryTimeAndToken(const RetryTimeAndToken& retry_info);
181 
182   // Helper that moves to the given object state if the given status represents
183   // a permanent error.
184   void UpdateObjectStateIfPermanentError(
185       absl::Status status, ObjectState permanent_error_object_state);
186 
187   // Utility struct to represent resource data coming from the gRPC protocol.
188   // A resource is either represented by a URI from which the data should be
189   // fetched (in which case `has_uri` is true and `uri` should not be empty), or
190   // is available as inline data (in which case `has_uri` is false and `data`
191   // may or may not be empty).
192   struct TaskResource {
193     bool has_uri;
194     const std::string& uri;
195     const std::string& data;
196     // The following fields will be set if the client should attempt to cache
197     // the resource.
198     const std::string& client_cache_id;
199     const absl::Duration max_age;
200   };
201   // Represents the common set of resources a task may have.
202   struct TaskResources {
203     TaskResource plan;
204     TaskResource checkpoint;
205   };
206 
207   // Helper function for fetching the checkpoint/plan resources for an
208   // eligibility eval task or regular task. This function will return an error
209   // if either `TaskResource` represents an invalid state (e.g. if `has_uri &&
210   // uri.empty()`).
211   absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources(
212       TaskResources task_resources);
213   // Validates the given `TaskResource` and converts it to a `UriOrInlineData`
214   // object for use with the `FetchResourcesInMemory` utility method.
215   absl::StatusOr<::fcp::client::http::UriOrInlineData>
216   ConvertResourceToUriOrInlineData(const TaskResource& resource);
217 
218   ObjectState object_state_;
219   EventPublisher* const event_publisher_;
220   LogManager* const log_manager_;
221   std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_;
222   const Flags* const flags_;
223   ::fcp::client::http::HttpClient* const http_client_;
224   std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream_;
225   std::unique_ptr<InterruptibleRunner> interruptible_runner_;
226   const std::string population_name_;
227   const std::string retry_token_;
228   const std::string client_version_;
229   const std::string attestation_measurement_;
230   std::function<absl::StatusOr<bool>()> should_abort_;
231   absl::BitGen bit_gen_;
232   // The set of canonical error codes that should be treated as 'permanent'
233   // errors.
234   absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_;
235   int64_t http_bytes_downloaded_ = 0;
236   int64_t http_bytes_uploaded_ = 0;
237   std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
238       WallClockStopwatch::Create();
239   // Represents 2 absolute retry timestamps and their corresponding retry
240   // tokens, to use when the device is rejected or accepted. The retry
241   // timestamps will have been generated based on the retry windows specified in
242   // the server's CheckinRequestAck message and the time at which that message
243   // was received.
244   struct CheckinRequestAckInfo {
245     RetryTimeAndToken retry_info_if_rejected;
246     RetryTimeAndToken retry_info_if_accepted;
247   };
248   // Represents the information received via the CheckinRequestAck message.
249   // This field will have an absent value until that message has been received.
250   std::optional<CheckinRequestAckInfo> checkin_request_ack_info_;
251   // The identifier of the task that was received in a CheckinResponse. Note
252   // that this does not refer to the identifier of the eligbility eval task that
253   // may have been received in an EligibilityEvalCheckinResponse.
254   std::string execution_phase_id_;
255   absl::flat_hash_map<
256       std::string, google::internal::federatedml::v2::SideChannelExecutionInfo>
257       side_channels_;
258   google::internal::federatedml::v2::SideChannelProtocolExecutionInfo
259       side_channel_protocol_execution_info_;
260   google::internal::federatedml::v2::SideChannelProtocolOptionsResponse
261       side_channel_protocol_options_response_;
262   // `nullptr` if the feature is disabled.
263   cache::ResourceCache* resource_cache_;
264 };
265 
266 }  // namespace client
267 }  // namespace fcp
268 
269 #endif  // FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
270