/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_
#define FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "fcp/base/monitoring.h"
#include "fcp/base/wall_clock_stopwatch.h"
#include "fcp/client/cache/resource_cache.h"
#include "fcp/client/engine/engine.pb.h"
#include "fcp/client/event_publisher.h"
#include "fcp/client/federated_protocol.h"
#include "fcp/client/fl_runner.pb.h"
#include "fcp/client/flags.h"
#include "fcp/client/grpc_bidi_stream.h"
#include "fcp/client/http/http_client.h"
#include "fcp/client/http/in_memory_request_response.h"
#include "fcp/client/interruptible_runner.h"
#include "fcp/client/log_manager.h"
#include "fcp/client/secagg_runner.h"
#include "fcp/client/selector_context.pb.h"
#include "fcp/client/stats.h"
#include "fcp/protocol/grpc_chunked_bidi_stream.h"
#include "fcp/protos/federated_api.pb.h"
#include "fcp/protos/plan.pb.h"
#include "fcp/secagg/client/secagg_client.h"

namespace fcp {
namespace client {

// Implements a single session of the gRPC-based Federated Learning protocol.
//
// The class derives from `FederatedProtocol` and overrides the check-in,
// task-assignment, reporting, retry-window and network-stats entry points
// declared in the public section below.
class GrpcFederatedProtocol : public ::fcp::client::FederatedProtocol {
 public:
  // Production constructor. Unlike the test constructor below, it receives the
  // connection parameters (`federated_service_uri`, `api_key`,
  // `test_cert_path`, `grpc_channel_deadline_seconds`) instead of a ready-made
  // stream and bit generator.
  //
  // NOTE(review): `event_publisher`, `log_manager`, `flags`, `http_client` and
  // `resource_cache` are passed as raw pointers and have matching raw-pointer
  // members; presumably they are not owned and must outlive this object --
  // confirm against the callers. `secagg_runner_factory` is passed as a
  // `std::unique_ptr`, i.e. ownership is handed to this object.
  GrpcFederatedProtocol(
      EventPublisher* event_publisher, LogManager* log_manager,
      std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
      const Flags* flags, ::fcp::client::http::HttpClient* http_client,
      const std::string& federated_service_uri, const std::string& api_key,
      const std::string& test_cert_path, absl::string_view population_name,
      absl::string_view retry_token, absl::string_view client_version,
      absl::string_view attestation_measurement,
      std::function<bool()> should_abort,
      const InterruptibleRunner::TimingConfig& timing_config,
      const int64_t grpc_channel_deadline_seconds,
      cache::ResourceCache* resource_cache);

  // Test c'tor. Takes an already-constructed `grpc_bidi_stream` and a
  // `bit_gen` in place of the connection parameters accepted by the
  // constructor above.
  GrpcFederatedProtocol(
      EventPublisher* event_publisher, LogManager* log_manager,
      std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory,
      const Flags* flags, ::fcp::client::http::HttpClient* http_client,
      std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream,
      absl::string_view population_name, absl::string_view retry_token,
      absl::string_view client_version,
      absl::string_view attestation_measurement,
      std::function<bool()> should_abort, absl::BitGen bit_gen,
      const InterruptibleRunner::TimingConfig& timing_config,
      cache::ResourceCache* resource_cache);

  ~GrpcFederatedProtocol() override;

  // Overrides `FederatedProtocol::EligibilityEvalCheckin`.
  absl::StatusOr<::fcp::client::FederatedProtocol::EligibilityEvalCheckinResult>
  EligibilityEvalCheckin(std::function<void(const EligibilityEvalTask&)>
                             payload_uris_received_callback) override;

  // Overrides `FederatedProtocol::ReportEligibilityEvalError`.
  void ReportEligibilityEvalError(absl::Status error_status) override;

  // Overrides `FederatedProtocol::Checkin`. `task_eligibility_info` is
  // optional and may be absent.
  absl::StatusOr<::fcp::client::FederatedProtocol::CheckinResult> Checkin(
      const std::optional<
          google::internal::federatedml::v2::TaskEligibilityInfo>&
          task_eligibility_info,
      std::function<void(const TaskAssignment&)> payload_uris_received_callback)
      override;

  // Overrides `FederatedProtocol::PerformMultipleTaskAssignments`.
  absl::StatusOr<::fcp::client::FederatedProtocol::MultipleTaskAssignments>
  PerformMultipleTaskAssignments(
      const std::vector<std::string>& task_names) override;

  // Overrides `FederatedProtocol::ReportCompleted`.
  absl::Status ReportCompleted(
      ComputationResults results, absl::Duration plan_duration,
      std::optional<std::string> aggregation_session_id) override;

  // Overrides `FederatedProtocol::ReportNotCompleted`.
  absl::Status ReportNotCompleted(
      engine::PhaseOutcome phase_outcome, absl::Duration plan_duration,
      std::optional<std::string> aggregation_session_id) override;

  // Overrides `FederatedProtocol::GetLatestRetryWindow`.
  google::internal::federatedml::v2::RetryWindow GetLatestRetryWindow()
      override;

  // Overrides `FederatedProtocol::GetNetworkStats`.
  NetworkStats GetNetworkStats() override;

 private:
  // Internal implementation of reporting for use by ReportCompleted() and
  // ReportNotCompleted().
  absl::Status Report(ComputationResults results,
                      engine::PhaseOutcome phase_outcome,
                      absl::Duration plan_duration);
  absl::Status ReportInternal(
      std::string tf_checkpoint, engine::PhaseOutcome phase_outcome,
      absl::Duration plan_duration,
      fcp::secagg::ClientToServerWrapperMessage* secagg_commit_message);

  // Helper function to send a ClientStreamMessage. If sending did not succeed,
  // closes the underlying grpc stream. If sending does succeed then it updates
  // `bytes_uploaded_`.
  // NOTE(review): this class declares `http_bytes_uploaded_` but no
  // `bytes_uploaded_`; confirm which counter the sentence above refers to.
  absl::Status Send(google::internal::federatedml::v2::ClientStreamMessage*
                        client_stream_message);

  // Helper function to receive a ServerStreamMessage. If receiving did not
  // succeed, closes the underlying grpc stream. If receiving does succeed then
  // it updates `bytes_downloaded_`.
  // NOTE(review): this class declares `http_bytes_downloaded_` but no
  // `bytes_downloaded_`; confirm which counter the sentence above refers to.
  absl::Status Receive(google::internal::federatedml::v2::ServerStreamMessage*
                           server_stream_message);

  // Helper function to compose a ProtocolOptionsRequest for eligibility eval or
  // regular checkin requests.
  google::internal::federatedml::v2::ProtocolOptionsRequest
  CreateProtocolOptionsRequest(bool should_ack_checkin) const;

  // Helper function to compose and send an EligibilityEvalCheckinRequest to the
  // server.
  absl::Status SendEligibilityEvalCheckinRequest();

  // Helper function to compose and send a CheckinRequest to the server.
  absl::Status SendCheckinRequest(
      const std::optional<
          google::internal::federatedml::v2::TaskEligibilityInfo>&
          task_eligibility_info);

  // Helper to receive + process a CheckinRequestAck message.
  absl::Status ReceiveCheckinRequestAck();

  // Helper to receive + process an EligibilityEvalCheckinResponse message.
  absl::StatusOr<EligibilityEvalCheckinResult>
  ReceiveEligibilityEvalCheckinResponse(
      absl::Time start_time, std::function<void(const EligibilityEvalTask&)>
                                 payload_uris_received_callback);

  // Helper to receive + process a CheckinResponse message.
  absl::StatusOr<CheckinResult> ReceiveCheckinResponse(
      absl::Time start_time, std::function<void(const TaskAssignment&)>
                                 payload_uris_received_callback);

  // Utility class for holding an absolute retry time and a corresponding retry
  // token.
  struct RetryTimeAndToken {
    absl::Time retry_time;
    std::string retry_token;
  };
  // Helper to generate a RetryWindow from a given time and token.
  google::internal::federatedml::v2::RetryWindow
  GenerateRetryWindowFromRetryTimeAndToken(const RetryTimeAndToken& retry_info);

  // Helper that moves to the given object state if the given status represents
  // a permanent error.
  void UpdateObjectStateIfPermanentError(
      absl::Status status, ObjectState permanent_error_object_state);

  // Utility struct to represent resource data coming from the gRPC protocol.
  // A resource is either represented by a URI from which the data should be
  // fetched (in which case `has_uri` is true and `uri` should not be empty), or
  // is available as inline data (in which case `has_uri` is false and `data`
  // may or may not be empty).
  //
  // The string members are references: a TaskResource does not own the text it
  // points at, so the referenced strings must stay alive for as long as the
  // struct is in use.
  struct TaskResource {
    bool has_uri;
    const std::string& uri;
    const std::string& data;
    // The following fields will be set if the client should attempt to cache
    // the resource.
    const std::string& client_cache_id;
    const absl::Duration max_age;
  };
  // Represents the common set of resources a task may have.
  struct TaskResources {
    TaskResource plan;
    TaskResource checkpoint;
  };

  // Helper function for fetching the checkpoint/plan resources for an
  // eligibility eval task or regular task. This function will return an error
  // if either `TaskResource` represents an invalid state (e.g. if `has_uri &&
  // uri.empty()`).
  absl::StatusOr<PlanAndCheckpointPayloads> FetchTaskResources(
      TaskResources task_resources);
  // Validates the given `TaskResource` and converts it to a `UriOrInlineData`
  // object for use with the `FetchResourcesInMemory` utility method.
  absl::StatusOr<::fcp::client::http::UriOrInlineData>
  ConvertResourceToUriOrInlineData(const TaskResource& resource);

  ObjectState object_state_;
  EventPublisher* const event_publisher_;
  LogManager* const log_manager_;
  std::unique_ptr<SecAggRunnerFactory> secagg_runner_factory_;
  const Flags* const flags_;
  ::fcp::client::http::HttpClient* const http_client_;
  std::unique_ptr<GrpcBidiStreamInterface> grpc_bidi_stream_;
  std::unique_ptr<InterruptibleRunner> interruptible_runner_;
  const std::string population_name_;
  const std::string retry_token_;
  const std::string client_version_;
  const std::string attestation_measurement_;
  // NOTE(review): both constructors accept `std::function<bool()>`, whereas
  // this member returns `absl::StatusOr<bool>`; confirm the wider return type
  // is intentional.
  std::function<absl::StatusOr<bool>()> should_abort_;
  absl::BitGen bit_gen_;
  // The set of canonical error codes that should be treated as 'permanent'
  // errors.
  absl::flat_hash_set<int32_t> federated_training_permanent_error_codes_;
  int64_t http_bytes_downloaded_ = 0;
  int64_t http_bytes_uploaded_ = 0;
  std::unique_ptr<WallClockStopwatch> network_stopwatch_ =
      WallClockStopwatch::Create();
  // Represents 2 absolute retry timestamps and their corresponding retry
  // tokens, to use when the device is rejected or accepted. The retry
  // timestamps will have been generated based on the retry windows specified in
  // the server's CheckinRequestAck message and the time at which that message
  // was received.
  struct CheckinRequestAckInfo {
    RetryTimeAndToken retry_info_if_rejected;
    RetryTimeAndToken retry_info_if_accepted;
  };
  // Represents the information received via the CheckinRequestAck message.
  // This field will have an absent value until that message has been received.
  std::optional<CheckinRequestAckInfo> checkin_request_ack_info_;
  // The identifier of the task that was received in a CheckinResponse. Note
  // that this does not refer to the identifier of the eligibility eval task
  // that may have been received in an EligibilityEvalCheckinResponse.
  std::string execution_phase_id_;
  absl::flat_hash_map<
      std::string, google::internal::federatedml::v2::SideChannelExecutionInfo>
      side_channels_;
  google::internal::federatedml::v2::SideChannelProtocolExecutionInfo
      side_channel_protocol_execution_info_;
  google::internal::federatedml::v2::SideChannelProtocolOptionsResponse
      side_channel_protocol_options_response_;
  // `nullptr` if the feature is disabled.
  cache::ResourceCache* resource_cache_;
};

}  // namespace client
}  // namespace fcp

#endif  // FCP_CLIENT_GRPC_FEDERATED_PROTOCOL_H_