1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard Worker #include <fstream>
18*14675a02SAndroid Build Coastguard Worker #include <optional>
19*14675a02SAndroid Build Coastguard Worker #include <string>
20*14675a02SAndroid Build Coastguard Worker #include <utility>
21*14675a02SAndroid Build Coastguard Worker
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Worker #include "absl/flags/flag.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/flags/parse.h"
25*14675a02SAndroid Build Coastguard Worker #include "absl/flags/usage.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
28*14675a02SAndroid Build Coastguard Worker #include "absl/strings/str_split.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
30*14675a02SAndroid Build Coastguard Worker #include "fcp/client/client_runner.h"
31*14675a02SAndroid Build Coastguard Worker #include "fcp/client/client_runner_example_data.pb.h"
32*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fake_event_publisher.h"
33*14675a02SAndroid Build Coastguard Worker #include "fcp/client/fl_runner.h"
34*14675a02SAndroid Build Coastguard Worker
35*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, server, "",
36*14675a02SAndroid Build Coastguard Worker "Federated Server URI (supports https+test:// and https:// URIs");
37*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, api_key, "", "API Key");
38*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, test_cert, "",
39*14675a02SAndroid Build Coastguard Worker "Path to test CA certificate PEM file; used for https+test:// URIs");
40*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, session, "", "Session name");
41*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, population, "", "Population name");
42*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, retry_token, "", "Retry token");
43*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, client_version, "", "Client version");
44*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, attestation_string, "", "Attestation string");
45*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(std::string, example_data_path, "",
46*14675a02SAndroid Build Coastguard Worker "Path to a serialized ClientRunnerExampleData proto with client "
47*14675a02SAndroid Build Coastguard Worker "example data. Falls back to --num_empty_examples if unset.");
48*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(int, num_empty_examples, 0,
49*14675a02SAndroid Build Coastguard Worker "Number of (empty) examples each created iterator serves. Ignored if "
50*14675a02SAndroid Build Coastguard Worker "--example_store_path is set.");
51*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(int, num_rounds, 1, "Number of rounds to train");
52*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(int, sleep_after_round_secs, 3,
53*14675a02SAndroid Build Coastguard Worker "Number of seconds to sleep after each round.");
54*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(bool, use_http_federated_compute_protocol, false,
55*14675a02SAndroid Build Coastguard Worker "Whether to enable the HTTP FederatedCompute protocol instead "
56*14675a02SAndroid Build Coastguard Worker "of the gRPC FederatedTrainingApi protocol.");
57*14675a02SAndroid Build Coastguard Worker ABSL_FLAG(bool, use_tflite_training, false, "Whether use TFLite for training.");
58*14675a02SAndroid Build Coastguard Worker
// Usage text shown by --help; describes what a single invocation does.
static constexpr char kUsageString[] =
    "Stand-alone Federated Client Executable.\n\n"
    "Connects to the specified server, tries to retrieve a plan, run the\n"
    "plan (feeding either the examples loaded from --example_data_path or\n"
    "the specified number of empty examples), and report the results of\n"
    "the computation back to the server.";
64*14675a02SAndroid Build Coastguard Worker
LoadExampleData(const std::string & examples_path)65*14675a02SAndroid Build Coastguard Worker static absl::StatusOr<fcp::client::ClientRunnerExampleData> LoadExampleData(
66*14675a02SAndroid Build Coastguard Worker const std::string& examples_path) {
67*14675a02SAndroid Build Coastguard Worker std::ifstream examples_file(examples_path);
68*14675a02SAndroid Build Coastguard Worker fcp::client::ClientRunnerExampleData data;
69*14675a02SAndroid Build Coastguard Worker if (!data.ParseFromIstream(&examples_file) || !examples_file.eof()) {
70*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError(
71*14675a02SAndroid Build Coastguard Worker "Failed to parse ClientRunnerExampleData");
72*14675a02SAndroid Build Coastguard Worker }
73*14675a02SAndroid Build Coastguard Worker return data;
74*14675a02SAndroid Build Coastguard Worker }
75*14675a02SAndroid Build Coastguard Worker
main(int argc,char ** argv)76*14675a02SAndroid Build Coastguard Worker int main(int argc, char** argv) {
77*14675a02SAndroid Build Coastguard Worker absl::SetProgramUsageMessage(kUsageString);
78*14675a02SAndroid Build Coastguard Worker absl::ParseCommandLine(argc, argv);
79*14675a02SAndroid Build Coastguard Worker
80*14675a02SAndroid Build Coastguard Worker int num_rounds = absl::GetFlag(FLAGS_num_rounds);
81*14675a02SAndroid Build Coastguard Worker std::string server = absl::GetFlag(FLAGS_server);
82*14675a02SAndroid Build Coastguard Worker std::string session = absl::GetFlag(FLAGS_session);
83*14675a02SAndroid Build Coastguard Worker std::string population = absl::GetFlag(FLAGS_population);
84*14675a02SAndroid Build Coastguard Worker std::string client_version = absl::GetFlag(FLAGS_client_version);
85*14675a02SAndroid Build Coastguard Worker std::string test_cert = absl::GetFlag(FLAGS_test_cert);
86*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Running for " << num_rounds << " rounds:";
87*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << " - server: " << server;
88*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << " - session: " << session;
89*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << " - population: " << population;
90*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << " - client_version: " << client_version;
91*14675a02SAndroid Build Coastguard Worker
92*14675a02SAndroid Build Coastguard Worker std::optional<fcp::client::ClientRunnerExampleData> example_data;
93*14675a02SAndroid Build Coastguard Worker if (std::string path = absl::GetFlag(FLAGS_example_data_path);
94*14675a02SAndroid Build Coastguard Worker !path.empty()) {
95*14675a02SAndroid Build Coastguard Worker auto statusor = LoadExampleData(path);
96*14675a02SAndroid Build Coastguard Worker if (!statusor.ok()) {
97*14675a02SAndroid Build Coastguard Worker FCP_LOG(ERROR) << "Failed to load example data: " << statusor.status();
98*14675a02SAndroid Build Coastguard Worker return 1;
99*14675a02SAndroid Build Coastguard Worker }
100*14675a02SAndroid Build Coastguard Worker example_data = *std::move(statusor);
101*14675a02SAndroid Build Coastguard Worker }
102*14675a02SAndroid Build Coastguard Worker
103*14675a02SAndroid Build Coastguard Worker bool success = false;
104*14675a02SAndroid Build Coastguard Worker for (auto i = 0; i < num_rounds || num_rounds < 0; ++i) {
105*14675a02SAndroid Build Coastguard Worker fcp::client::FederatedTaskEnvDepsImpl federated_task_env_deps_impl =
106*14675a02SAndroid Build Coastguard Worker example_data
107*14675a02SAndroid Build Coastguard Worker ? fcp::client::FederatedTaskEnvDepsImpl(*example_data, test_cert)
108*14675a02SAndroid Build Coastguard Worker : fcp::client::FederatedTaskEnvDepsImpl(
109*14675a02SAndroid Build Coastguard Worker absl::GetFlag(FLAGS_num_empty_examples), test_cert);
110*14675a02SAndroid Build Coastguard Worker fcp::client::FakeEventPublisher event_publisher(/*quiet=*/false);
111*14675a02SAndroid Build Coastguard Worker fcp::client::FilesImpl files_impl;
112*14675a02SAndroid Build Coastguard Worker fcp::client::LogManagerImpl log_manager_impl;
113*14675a02SAndroid Build Coastguard Worker fcp::client::FlagsImpl flags;
114*14675a02SAndroid Build Coastguard Worker flags.set_use_http_federated_compute_protocol(
115*14675a02SAndroid Build Coastguard Worker absl::GetFlag(FLAGS_use_http_federated_compute_protocol));
116*14675a02SAndroid Build Coastguard Worker flags.set_use_tflite_training(absl::GetFlag(FLAGS_use_tflite_training));
117*14675a02SAndroid Build Coastguard Worker
118*14675a02SAndroid Build Coastguard Worker auto fl_runner_result = RunFederatedComputation(
119*14675a02SAndroid Build Coastguard Worker &federated_task_env_deps_impl, &event_publisher, &files_impl,
120*14675a02SAndroid Build Coastguard Worker &log_manager_impl, &flags, server, absl::GetFlag(FLAGS_api_key),
121*14675a02SAndroid Build Coastguard Worker test_cert, session, population, absl::GetFlag(FLAGS_retry_token),
122*14675a02SAndroid Build Coastguard Worker client_version, absl::GetFlag(FLAGS_attestation_string));
123*14675a02SAndroid Build Coastguard Worker if (fl_runner_result.ok()) {
124*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Run finished successfully; result: "
125*14675a02SAndroid Build Coastguard Worker << fl_runner_result.value().DebugString();
126*14675a02SAndroid Build Coastguard Worker success = true;
127*14675a02SAndroid Build Coastguard Worker } else {
128*14675a02SAndroid Build Coastguard Worker FCP_LOG(ERROR) << "Error during run: " << fl_runner_result.status();
129*14675a02SAndroid Build Coastguard Worker }
130*14675a02SAndroid Build Coastguard Worker int sleep_secs = absl::GetFlag(FLAGS_sleep_after_round_secs);
131*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << "Sleeping for " << sleep_secs << " secs";
132*14675a02SAndroid Build Coastguard Worker absl::SleepFor(absl::Seconds(sleep_secs));
133*14675a02SAndroid Build Coastguard Worker }
134*14675a02SAndroid Build Coastguard Worker return success ? 0 : 1;
135*14675a02SAndroid Build Coastguard Worker }
136