xref: /aosp_15_r20/external/federated-compute/fcp/client/lc_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/lc_runner.h"
17 
18 #include <functional>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/time/time.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/base/platform.h"
31 #include "fcp/client/engine/example_iterator_factory.h"
32 #include "fcp/client/engine/plan_engine_helpers.h"
33 
34 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
35 #include "fcp/client/engine/simple_plan_engine.h"
36 #endif
37 
38 #include "fcp/client/engine/tflite_plan_engine.h"
39 #include "fcp/client/opstats/opstats_example_store.h"
40 #include "fcp/client/phase_logger_impl.h"
41 #include "fcp/client/selector_context.pb.h"
42 #include "fcp/protos/plan.pb.h"
43 #include "tensorflow/core/framework/tensor.h"
44 #include "tensorflow/core/framework/tensor.pb.h"
45 #include "tensorflow/core/framework/tensor_shape.pb.h"
46 #include "tensorflow/core/protobuf/struct.pb.h"
47 
48 namespace fcp {
49 namespace client {
50 
51 using ::fcp::client::opstats::OpStatsLogger;
52 using ::google::internal::federated::plan::ClientOnlyPlan;
53 using ::google::internal::federated::plan::LocalComputeIORouter;
54 
55 using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
56 using TfMobileInputs = std::vector<std::pair<std::string, tensorflow::Tensor>>;
57 
58 namespace {
59 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
60 absl::StatusOr<std::unique_ptr<TfMobileInputs>>
ConstructInputsForTensorflowSpecPlan(const LocalComputeIORouter & local_compute,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)61 ConstructInputsForTensorflowSpecPlan(
62     const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
63     const std::string& output_dir_uri,
64     const absl::flat_hash_map<std::string, std::string>& input_resources) {
65   auto inputs = std::make_unique<
66       std::vector<std::pair<std::string, tensorflow::Tensor>>>();
67   if (local_compute.has_multiple_input_resources()) {
68     if (!input_dir_uri.empty()) {
69       return absl::InvalidArgumentError(
70           "Both input dir and input resources are provided.");
71     }
72     auto input_resource_tensor_name_map =
73         local_compute.multiple_input_resources()
74             .input_resource_tensor_name_map();
75     for (const auto& resource : input_resources) {
76       tensorflow::Tensor resource_tensor(tensorflow::DT_STRING, {});
77       resource_tensor.scalar<tensorflow::tstring>()() = resource.second;
78       if (!input_resource_tensor_name_map.contains(resource.first)) {
79         return absl::InvalidArgumentError(
80             absl::StrCat("User provided input resource:", resource.first,
81                          " is missing in LocalComputeIORouter."));
82       }
83       std::string tensor_name = input_resource_tensor_name_map[resource.first];
84       inputs->push_back({tensor_name, resource_tensor});
85     }
86   } else {
87     tensorflow::Tensor input_dirpath(tensorflow::DT_STRING, {});
88     input_dirpath.scalar<tensorflow::tstring>()() = input_dir_uri;
89     inputs->push_back({local_compute.input_dir_tensor_name(), input_dirpath});
90   }
91   tensorflow::Tensor output_dirpath(tensorflow::DT_STRING, {});
92   output_dirpath.scalar<tensorflow::tstring>()() = output_dir_uri;
93   inputs->push_back({local_compute.output_dir_tensor_name(), output_dirpath});
94   return inputs;
95 }
96 #endif
97 
// Builds the TFLite input map (tensor name -> string value) for a local
// computation, using the tensor names declared in `local_compute`.
//
// If the router declares `multiple_input_resources`, each entry of
// `input_resources` whose key the router declares is fed under the mapped
// tensor name. In that mode a non-empty `input_dir_uri` is rejected with
// InvalidArgumentError. Otherwise `input_dir_uri` is fed under
// `input_dir_tensor_name()`. In both modes `output_dir_uri` is fed under
// `output_dir_tensor_name()`, and it is written last.
absl::StatusOr<std::unique_ptr<TfLiteInputs>> ConstructInputsForTFLitePlan(
    const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
    const std::string& output_dir_uri,
    const absl::flat_hash_map<std::string, std::string>& input_resources) {
  auto inputs = std::make_unique<TfLiteInputs>();
  if (local_compute.has_multiple_input_resources()) {
    if (!input_dir_uri.empty()) {
      return absl::InvalidArgumentError(
          "Both input dir and input resources are provided.");
    }
    // The map is only read here, so bind it by reference instead of
    // deep-copying the proto map.
    const auto& input_resource_tensor_name_map =
        local_compute.multiple_input_resources()
            .input_resource_tensor_name_map();
    for (const auto& [resource_name, resource_value] : input_resources) {
      auto tensor_name_it = input_resource_tensor_name_map.find(resource_name);
      if (tensor_name_it == input_resource_tensor_name_map.end()) {
        // If the user provided more input resources than required in the
        // LocalComputeIORouter, we simply continue without throwing an error.
        // In this way, the user could update their scheduling logic separately
        // from their local computation definitions.
        continue;
      }
      (*inputs)[tensor_name_it->second] = resource_value;
    }
  } else {
    (*inputs)[local_compute.input_dir_tensor_name()] = input_dir_uri;
  }
  (*inputs)[local_compute.output_dir_tensor_name()] = output_dir_uri;
  return inputs;
}
128 
LogComputationOutcome(engine::PlanResult plan_result,PhaseLogger & phase_logger,absl::Time run_plan_start_time,absl::Time reference_time)129 void LogComputationOutcome(engine::PlanResult plan_result,
130                            PhaseLogger& phase_logger,
131                            absl::Time run_plan_start_time,
132                            absl::Time reference_time) {
133   switch (plan_result.outcome) {
134     case engine::PlanOutcome::kSuccess:
135       phase_logger.LogComputationCompleted(plan_result.example_stats,
136                                            NetworkStats(), run_plan_start_time,
137                                            reference_time);
138       break;
139     case engine::PlanOutcome::kInterrupted:
140       phase_logger.LogComputationInterrupted(
141           plan_result.original_status, plan_result.example_stats,
142           NetworkStats(), run_plan_start_time, reference_time);
143       break;
144     case engine::PlanOutcome::kInvalidArgument:
145       phase_logger.LogComputationInvalidArgument(
146           plan_result.original_status, plan_result.example_stats,
147           NetworkStats(), run_plan_start_time);
148       break;
149     case engine::PlanOutcome::kTensorflowError:
150       phase_logger.LogComputationTensorflowError(
151           std::move(plan_result.original_status), plan_result.example_stats,
152           NetworkStats(), run_plan_start_time, reference_time);
153       break;
154     case engine::PlanOutcome::kExampleIteratorError:
155       phase_logger.LogComputationExampleIteratorError(
156           plan_result.original_status, plan_result.example_stats,
157           NetworkStats(), run_plan_start_time);
158       break;
159   }
160 }
161 
162 // Creates an ExampleIteratorFactory that routes queries to the
163 // SimpleTaskEnvironment::CreateExampleIterator() method.
164 std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)165 CreateSimpleTaskEnvironmentIteratorFactory(
166     SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
167   return std::make_unique<engine::FunctionalExampleIteratorFactory>(
168       /*can_handle_func=*/
169       [](const google::internal::federated::plan::ExampleSelector&) {
170         // The SimpleTaskEnvironment-based ExampleIteratorFactory should
171         // be the catch-all factory that is able to handle all queries
172         // that no other ExampleIteratorFactory is able to handle.
173         return true;
174       },
175       /*create_iterator_func=*/
176       [task_env, selector_context](
177           const google::internal::federated::plan::ExampleSelector&
178               example_selector) {
179         return task_env->CreateExampleIterator(example_selector,
180                                                selector_context);
181       },
182       /*should_collect_stats=*/true);
183 }
184 
RunPlanWithTensorflowSpec(PhaseLogger & phase_logger,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)185 absl::Status RunPlanWithTensorflowSpec(
186     PhaseLogger& phase_logger,
187     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
188     std::function<bool()> should_abort, LogManager* log_manager,
189     OpStatsLogger* opstats_logger, const Flags* flags,
190     const ClientOnlyPlan& client_plan, const std::string& input_dir_uri,
191     const std::string& output_dir_uri,
192     const absl::flat_hash_map<std::string, std::string>& input_resources,
193     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
194     const absl::Time run_plan_start_time, const absl::Time reference_time) {
195   // Check that this is a TensorflowSpec-based plan for local computation.
196   if (!client_plan.phase().has_tensorflow_spec()) {
197     absl::Status error_status =
198         absl::InvalidArgumentError("Plan without TensorflowSpec");
199     phase_logger.LogComputationInvalidArgument(
200         error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
201     return error_status;
202   }
203   if (!client_plan.phase().has_local_compute()) {
204     absl::Status error_status =
205         absl::InvalidArgumentError("Invalid TensorflowSpec-based plan");
206     phase_logger.LogComputationInvalidArgument(
207         error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
208     return error_status;
209   }
210 
211   // Run plan
212   std::vector<std::string> output_names_unused;
213 
214   if (!client_plan.tflite_graph().empty()) {
215     log_manager->LogDiag(
216         ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
217   }
218 
219   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
220     auto inputs = ConstructInputsForTFLitePlan(
221         client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
222         input_resources);
223     if (!inputs.ok()) {
224       phase_logger.LogComputationInvalidArgument(
225           inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
226       return inputs.status();
227     }
228     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
229                                          should_abort, log_manager,
230                                          opstats_logger, flags, &timing_config);
231     engine::PlanResult plan_result = plan_engine.RunPlan(
232         client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
233         std::move(*inputs), output_names_unused);
234     engine::PlanOutcome outcome = plan_result.outcome;
235     LogComputationOutcome(std::move(plan_result), phase_logger,
236                           run_plan_start_time, reference_time);
237     return ConvertPlanOutcomeToStatus(outcome);
238   }
239 
240 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
241   // Construct input tensors based on the values in the LocalComputeIORouter
242   // message.
243   auto inputs = ConstructInputsForTensorflowSpecPlan(
244       client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
245       input_resources);
246   if (!inputs.ok()) {
247     phase_logger.LogComputationInvalidArgument(
248         inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
249     return inputs.status();
250   }
251   engine::SimplePlanEngine plan_engine(
252       example_iterator_factories, should_abort, log_manager, opstats_logger,
253       &timing_config, flags->support_constant_tf_inputs());
254   engine::PlanResult plan_result = plan_engine.RunPlan(
255       client_plan.phase().tensorflow_spec(), client_plan.graph(),
256       client_plan.tensorflow_config_proto(), std::move(*inputs),
257       output_names_unused);
258   engine::PlanOutcome outcome = plan_result.outcome;
259   LogComputationOutcome(std::move(plan_result), phase_logger,
260                         run_plan_start_time, reference_time);
261   return ConvertPlanOutcomeToStatus(outcome);
262 #else
263   return absl::InternalError("No plan engine enabled");
264 #endif
265 }
266 }  // anonymous namespace
267 
RunLocalComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,LogManager * log_manager,const Flags * flags,const std::string & session_name,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)268 absl::Status RunLocalComputation(
269     SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
270     LogManager* log_manager, const Flags* flags,
271     const std::string& session_name, const std::string& plan_uri,
272     const std::string& input_dir_uri, const std::string& output_dir_uri,
273     const absl::flat_hash_map<std::string, std::string>& input_resources) {
274   auto opstats_logger = engine::CreateOpStatsLogger(
275       env_deps->GetBaseDir(), flags, log_manager, session_name,
276       /*population_name=*/"");
277   SelectorContext selector_context;
278   selector_context.mutable_computation_properties()->set_session_name(
279       session_name);
280   LocalComputation computation = LocalComputation();
281   computation.set_input_dir(input_dir_uri);
282   computation.set_output_dir(output_dir_uri);
283   computation.mutable_input_resource_map()->insert(input_resources.begin(),
284                                                    input_resources.end());
285   *selector_context.mutable_computation_properties()->mutable_local_compute() =
286       computation;
287   PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
288                                log_manager, flags);
289   return RunLocalComputation(phase_logger, env_deps, log_manager,
290                              opstats_logger.get(), flags, plan_uri,
291                              input_dir_uri, output_dir_uri, input_resources,
292                              selector_context);
293 }
294 
RunLocalComputation(PhaseLogger & phase_logger,SimpleTaskEnvironment * env_deps,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const SelectorContext & selector_context)295 absl::Status RunLocalComputation(
296     PhaseLogger& phase_logger, SimpleTaskEnvironment* env_deps,
297     LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
298     const std::string& plan_uri, const std::string& input_dir_uri,
299     const std::string& output_dir_uri,
300     const absl::flat_hash_map<std::string, std::string>& input_resources,
301     const SelectorContext& selector_context) {
302   absl::Time reference_time = absl::Now();
303   absl::Duration polling_period =
304       absl::Milliseconds(flags->condition_polling_period_millis());
305   std::function<bool()> should_abort = [env_deps, polling_period]() {
306     return env_deps->ShouldAbort(absl::Now(), polling_period);
307   };
308   // Check if the device conditions allow running a local computation.
309   if (should_abort()) {
310     std::string message =
311         "Device conditions not satisfied, aborting local computation";
312     FCP_LOG(INFO) << message;
313     phase_logger.LogTaskNotStarted(message);
314     return absl::CancelledError("");
315   }
316   // Local compute plans can use example iterators from the
317   // SimpleTaskEnvironment and those reading the OpStats DB.
318   opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
319       opstats_logger, log_manager,
320       flags->opstats_last_successful_contribution_criteria());
321   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
322       CreateSimpleTaskEnvironmentIteratorFactory(env_deps, selector_context);
323   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
324       &opstats_example_iterator_factory, env_example_iterator_factory.get()};
325 
326   fcp::client::InterruptibleRunner::TimingConfig timing_config = {
327       .polling_period = polling_period,
328       .graceful_shutdown_period = absl::Milliseconds(
329           flags->tf_execution_teardown_grace_period_millis()),
330       .extended_shutdown_period = absl::Milliseconds(
331           flags->tf_execution_teardown_extended_period_millis()),
332   };
333 
334   absl::Time run_plan_start_time = absl::Now();
335   phase_logger.LogComputationStarted();
336 
337   absl::StatusOr<std::string> plan_str = fcp::ReadFileToString(plan_uri);
338   if (!plan_str.ok()) {
339     phase_logger.LogComputationIOError(plan_str.status(), ExampleStats(),
340                                        NetworkStats(), run_plan_start_time);
341     return plan_str.status();
342   }
343 
344   ClientOnlyPlan plan;
345   if (!plan.ParseFromString(*plan_str)) {
346     absl::Status error_status =
347         absl::InvalidArgumentError("could not parse received plan");
348     phase_logger.LogComputationInvalidArgument(
349         error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
350     return error_status;
351   }
352 
353   std::vector<std::string> output_names;
354   std::vector<tensorflow::Tensor> output_tensors;
355   return RunPlanWithTensorflowSpec(
356       phase_logger, example_iterator_factories, should_abort, log_manager,
357       opstats_logger, flags, plan, input_dir_uri, output_dir_uri,
358       input_resources, timing_config, run_plan_start_time, reference_time);
359 }
360 
361 }  // namespace client
362 }  // namespace fcp
363