xref: /aosp_15_r20/external/federated-compute/fcp/client/client_runner_main.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #include <fstream>
18*14675a02SAndroid Build Coastguard Worker #include <optional>
19*14675a02SAndroid Build Coastguard Worker #include <string>
20*14675a02SAndroid Build Coastguard Worker #include <utility>
21*14675a02SAndroid Build Coastguard Worker 
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "absl/flags/flag.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/flags/parse.h"
25*14675a02SAndroid Build Coastguard Worker #include "absl/flags/usage.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
28*14675a02SAndroid Build Coastguard Worker #include "absl/strings/str_split.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
30*14675a02SAndroid Build Coastguard Worker #include "fcp/client/client_runner.h"
31*14675a02SAndroid Build Coastguard Worker #include "fcp/client/client_runner_example_data.pb.h"
32*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fake_event_publisher.h"
33*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.h"
34*14675a02SAndroid Build Coastguard Worker 
35*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, server, "",
36*14675a02SAndroid Build Coastguard Worker           "Federated Server URI (supports https+test:// and https:// URIs");
37*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, api_key, "", "API Key");
38*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, test_cert, "",
39*14675a02SAndroid Build Coastguard Worker           "Path to test CA certificate PEM file; used for https+test:// URIs");
40*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, session, "", "Session name");
41*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, population, "", "Population name");
42*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, retry_token, "", "Retry token");
43*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, client_version, "", "Client version");
44*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, attestation_string, "", "Attestation string");
45*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, example_data_path, "",
46*14675a02SAndroid Build Coastguard Worker           "Path to a serialized ClientRunnerExampleData proto with client "
47*14675a02SAndroid Build Coastguard Worker           "example data. Falls back to --num_empty_examples if unset.");
48*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(int, num_empty_examples, 0,
49*14675a02SAndroid Build Coastguard Worker           "Number of (empty) examples each created iterator serves. Ignored if "
50*14675a02SAndroid Build Coastguard Worker           "--example_store_path is set.");
51*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(int, num_rounds, 1, "Number of rounds to train");
52*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(int, sleep_after_round_secs, 3,
53*14675a02SAndroid Build Coastguard Worker           "Number of seconds to sleep after each round.");
54*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(bool, use_http_federated_compute_protocol, false,
55*14675a02SAndroid Build Coastguard Worker           "Whether to enable the HTTP FederatedCompute protocol instead "
56*14675a02SAndroid Build Coastguard Worker           "of the gRPC FederatedTrainingApi protocol.");
57*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(bool, use_tflite_training, false, "Whether use TFLite for training.");
58*14675a02SAndroid Build Coastguard Worker 
// Usage text registered with absl::SetProgramUsageMessage() in main().
// Written as a raw string literal so the line breaks below are the line
// breaks of the message itself.
static constexpr char kUsageString[] =
    R"(Stand-alone Federated Client Executable.

Connects to the specified server, tries to retrieve a plan, run the
plan (feeding the specified number of empty examples), and report the
results of the computation back to the server.)";
64*14675a02SAndroid Build Coastguard Worker 
LoadExampleData(const std::string & examples_path)65*14675a02SAndroid Build Coastguard Worker static absl::StatusOr<fcp::client::ClientRunnerExampleData> LoadExampleData(
66*14675a02SAndroid Build Coastguard Worker     const std::string& examples_path) {
67*14675a02SAndroid Build Coastguard Worker   std::ifstream examples_file(examples_path);
68*14675a02SAndroid Build Coastguard Worker   fcp::client::ClientRunnerExampleData data;
69*14675a02SAndroid Build Coastguard Worker   if (!data.ParseFromIstream(&examples_file) || !examples_file.eof()) {
70*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
71*14675a02SAndroid Build Coastguard Worker         "Failed to parse ClientRunnerExampleData");
72*14675a02SAndroid Build Coastguard Worker   }
73*14675a02SAndroid Build Coastguard Worker   return data;
74*14675a02SAndroid Build Coastguard Worker }
75*14675a02SAndroid Build Coastguard Worker 
main(int argc,char ** argv)76*14675a02SAndroid Build Coastguard Worker int main(int argc, char** argv) {
77*14675a02SAndroid Build Coastguard Worker   absl::SetProgramUsageMessage(kUsageString);
78*14675a02SAndroid Build Coastguard Worker   absl::ParseCommandLine(argc, argv);
79*14675a02SAndroid Build Coastguard Worker 
80*14675a02SAndroid Build Coastguard Worker   int num_rounds = absl::GetFlag(FLAGS_num_rounds);
81*14675a02SAndroid Build Coastguard Worker   std::string server = absl::GetFlag(FLAGS_server);
82*14675a02SAndroid Build Coastguard Worker   std::string session = absl::GetFlag(FLAGS_session);
83*14675a02SAndroid Build Coastguard Worker   std::string population = absl::GetFlag(FLAGS_population);
84*14675a02SAndroid Build Coastguard Worker   std::string client_version = absl::GetFlag(FLAGS_client_version);
85*14675a02SAndroid Build Coastguard Worker   std::string test_cert = absl::GetFlag(FLAGS_test_cert);
86*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << "Running for " << num_rounds << " rounds:";
87*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << " - server:         " << server;
88*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << " - session:        " << session;
89*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << " - population:     " << population;
90*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << " - client_version: " << client_version;
91*14675a02SAndroid Build Coastguard Worker 
92*14675a02SAndroid Build Coastguard Worker   std::optional<fcp::client::ClientRunnerExampleData> example_data;
93*14675a02SAndroid Build Coastguard Worker   if (std::string path = absl::GetFlag(FLAGS_example_data_path);
94*14675a02SAndroid Build Coastguard Worker       !path.empty()) {
95*14675a02SAndroid Build Coastguard Worker     auto statusor = LoadExampleData(path);
96*14675a02SAndroid Build Coastguard Worker     if (!statusor.ok()) {
97*14675a02SAndroid Build Coastguard Worker       FCP_LOG(ERROR) << "Failed to load example data: " << statusor.status();
98*14675a02SAndroid Build Coastguard Worker       return 1;
99*14675a02SAndroid Build Coastguard Worker     }
100*14675a02SAndroid Build Coastguard Worker     example_data = *std::move(statusor);
101*14675a02SAndroid Build Coastguard Worker   }
102*14675a02SAndroid Build Coastguard Worker 
103*14675a02SAndroid Build Coastguard Worker   bool success = false;
104*14675a02SAndroid Build Coastguard Worker   for (auto i = 0; i < num_rounds || num_rounds < 0; ++i) {
105*14675a02SAndroid Build Coastguard Worker     fcp::client::FederatedTaskEnvDepsImpl federated_task_env_deps_impl =
106*14675a02SAndroid Build Coastguard Worker         example_data
107*14675a02SAndroid Build Coastguard Worker             ? fcp::client::FederatedTaskEnvDepsImpl(*example_data, test_cert)
108*14675a02SAndroid Build Coastguard Worker             : fcp::client::FederatedTaskEnvDepsImpl(
109*14675a02SAndroid Build Coastguard Worker                   absl::GetFlag(FLAGS_num_empty_examples), test_cert);
110*14675a02SAndroid Build Coastguard Worker     fcp::client::FakeEventPublisher event_publisher(/*quiet=*/false);
111*14675a02SAndroid Build Coastguard Worker     fcp::client::FilesImpl files_impl;
112*14675a02SAndroid Build Coastguard Worker     fcp::client::LogManagerImpl log_manager_impl;
113*14675a02SAndroid Build Coastguard Worker     fcp::client::FlagsImpl flags;
114*14675a02SAndroid Build Coastguard Worker     flags.set_use_http_federated_compute_protocol(
115*14675a02SAndroid Build Coastguard Worker         absl::GetFlag(FLAGS_use_http_federated_compute_protocol));
116*14675a02SAndroid Build Coastguard Worker     flags.set_use_tflite_training(absl::GetFlag(FLAGS_use_tflite_training));
117*14675a02SAndroid Build Coastguard Worker 
118*14675a02SAndroid Build Coastguard Worker     auto fl_runner_result = RunFederatedComputation(
119*14675a02SAndroid Build Coastguard Worker         &federated_task_env_deps_impl, &event_publisher, &files_impl,
120*14675a02SAndroid Build Coastguard Worker         &log_manager_impl, &flags, server, absl::GetFlag(FLAGS_api_key),
121*14675a02SAndroid Build Coastguard Worker         test_cert, session, population, absl::GetFlag(FLAGS_retry_token),
122*14675a02SAndroid Build Coastguard Worker         client_version, absl::GetFlag(FLAGS_attestation_string));
123*14675a02SAndroid Build Coastguard Worker     if (fl_runner_result.ok()) {
124*14675a02SAndroid Build Coastguard Worker       FCP_LOG(INFO) << "Run finished successfully; result: "
125*14675a02SAndroid Build Coastguard Worker                     << fl_runner_result.value().DebugString();
126*14675a02SAndroid Build Coastguard Worker       success = true;
127*14675a02SAndroid Build Coastguard Worker     } else {
128*14675a02SAndroid Build Coastguard Worker       FCP_LOG(ERROR) << "Error during run: " << fl_runner_result.status();
129*14675a02SAndroid Build Coastguard Worker     }
130*14675a02SAndroid Build Coastguard Worker     int sleep_secs = absl::GetFlag(FLAGS_sleep_after_round_secs);
131*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << "Sleeping for " << sleep_secs << " secs";
132*14675a02SAndroid Build Coastguard Worker     absl::SleepFor(absl::Seconds(sleep_secs));
133*14675a02SAndroid Build Coastguard Worker   }
134*14675a02SAndroid Build Coastguard Worker   return success ? 0 : 1;
135*14675a02SAndroid Build Coastguard Worker }
136