xref: /aosp_15_r20/external/federated-compute/fcp/client/lc_runner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/lc_runner.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker #include <map>
20*14675a02SAndroid Build Coastguard Worker #include <memory>
21*14675a02SAndroid Build Coastguard Worker #include <string>
22*14675a02SAndroid Build Coastguard Worker #include <utility>
23*14675a02SAndroid Build Coastguard Worker #include <vector>
24*14675a02SAndroid Build Coastguard Worker 
25*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
28*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
30*14675a02SAndroid Build Coastguard Worker #include "fcp/base/platform.h"
31*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_iterator_factory.h"
32*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
33*14675a02SAndroid Build Coastguard Worker 
34*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
35*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/simple_plan_engine.h"
36*14675a02SAndroid Build Coastguard Worker #endif
37*14675a02SAndroid Build Coastguard Worker 
38*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tflite_plan_engine.h"
39*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_example_store.h"
40*14675a02SAndroid Build Coastguard Worker #include "fcp/client/phase_logger_impl.h"
41*14675a02SAndroid Build Coastguard Worker #include "fcp/client/selector_context.pb.h"
42*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
43*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h"
44*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.pb.h"
45*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.pb.h"
46*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/struct.pb.h"
47*14675a02SAndroid Build Coastguard Worker 
48*14675a02SAndroid Build Coastguard Worker namespace fcp {
49*14675a02SAndroid Build Coastguard Worker namespace client {
50*14675a02SAndroid Build Coastguard Worker 
51*14675a02SAndroid Build Coastguard Worker using ::fcp::client::opstats::OpStatsLogger;
52*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::ClientOnlyPlan;
53*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::LocalComputeIORouter;
54*14675a02SAndroid Build Coastguard Worker 
// Inputs for a TFLite plan run: string values keyed by string names (this
// file keys them by the tensor names taken from the LocalComputeIORouter).
55*14675a02SAndroid Build Coastguard Worker using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
// Inputs for a TFMobile plan run: an ordered list of (name, tensor) pairs.
56*14675a02SAndroid Build Coastguard Worker using TfMobileInputs = std::vector<std::pair<std::string, tensorflow::Tensor>>;
57*14675a02SAndroid Build Coastguard Worker 
58*14675a02SAndroid Build Coastguard Worker namespace {
59*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
60*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<TfMobileInputs>>
ConstructInputsForTensorflowSpecPlan(const LocalComputeIORouter & local_compute,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)61*14675a02SAndroid Build Coastguard Worker ConstructInputsForTensorflowSpecPlan(
62*14675a02SAndroid Build Coastguard Worker     const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
63*14675a02SAndroid Build Coastguard Worker     const std::string& output_dir_uri,
64*14675a02SAndroid Build Coastguard Worker     const absl::flat_hash_map<std::string, std::string>& input_resources) {
65*14675a02SAndroid Build Coastguard Worker   auto inputs = std::make_unique<
66*14675a02SAndroid Build Coastguard Worker       std::vector<std::pair<std::string, tensorflow::Tensor>>>();
67*14675a02SAndroid Build Coastguard Worker   if (local_compute.has_multiple_input_resources()) {
68*14675a02SAndroid Build Coastguard Worker     if (!input_dir_uri.empty()) {
69*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
70*14675a02SAndroid Build Coastguard Worker           "Both input dir and input resources are provided.");
71*14675a02SAndroid Build Coastguard Worker     }
72*14675a02SAndroid Build Coastguard Worker     auto input_resource_tensor_name_map =
73*14675a02SAndroid Build Coastguard Worker         local_compute.multiple_input_resources()
74*14675a02SAndroid Build Coastguard Worker             .input_resource_tensor_name_map();
75*14675a02SAndroid Build Coastguard Worker     for (const auto& resource : input_resources) {
76*14675a02SAndroid Build Coastguard Worker       tensorflow::Tensor resource_tensor(tensorflow::DT_STRING, {});
77*14675a02SAndroid Build Coastguard Worker       resource_tensor.scalar<tensorflow::tstring>()() = resource.second;
78*14675a02SAndroid Build Coastguard Worker       if (!input_resource_tensor_name_map.contains(resource.first)) {
79*14675a02SAndroid Build Coastguard Worker         return absl::InvalidArgumentError(
80*14675a02SAndroid Build Coastguard Worker             absl::StrCat("User provided input resource:", resource.first,
81*14675a02SAndroid Build Coastguard Worker                          " is missing in LocalComputeIORouter."));
82*14675a02SAndroid Build Coastguard Worker       }
83*14675a02SAndroid Build Coastguard Worker       std::string tensor_name = input_resource_tensor_name_map[resource.first];
84*14675a02SAndroid Build Coastguard Worker       inputs->push_back({tensor_name, resource_tensor});
85*14675a02SAndroid Build Coastguard Worker     }
86*14675a02SAndroid Build Coastguard Worker   } else {
87*14675a02SAndroid Build Coastguard Worker     tensorflow::Tensor input_dirpath(tensorflow::DT_STRING, {});
88*14675a02SAndroid Build Coastguard Worker     input_dirpath.scalar<tensorflow::tstring>()() = input_dir_uri;
89*14675a02SAndroid Build Coastguard Worker     inputs->push_back({local_compute.input_dir_tensor_name(), input_dirpath});
90*14675a02SAndroid Build Coastguard Worker   }
91*14675a02SAndroid Build Coastguard Worker   tensorflow::Tensor output_dirpath(tensorflow::DT_STRING, {});
92*14675a02SAndroid Build Coastguard Worker   output_dirpath.scalar<tensorflow::tstring>()() = output_dir_uri;
93*14675a02SAndroid Build Coastguard Worker   inputs->push_back({local_compute.output_dir_tensor_name(), output_dirpath});
94*14675a02SAndroid Build Coastguard Worker   return inputs;
95*14675a02SAndroid Build Coastguard Worker }
96*14675a02SAndroid Build Coastguard Worker #endif
97*14675a02SAndroid Build Coastguard Worker 
// Builds the tensor-name -> string-value input map for running a local compute
// plan with the TFLite plan engine.
//
//   * When `local_compute` declares multiple input resources, each entry of
//     `input_resources` whose key appears in `input_resource_tensor_name_map`
//     is stored under the mapped tensor name; entries with unknown keys are
//     skipped.
//   * Otherwise `input_dir_uri` is stored under the input-dir tensor name.
//   * `output_dir_uri` is always stored under the output-dir tensor name.
//
// Returns InvalidArgumentError when multiple input resources are declared and
// `input_dir_uri` is non-empty.
absl::StatusOr<std::unique_ptr<TfLiteInputs>> ConstructInputsForTFLitePlan(
    const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
    const std::string& output_dir_uri,
    const absl::flat_hash_map<std::string, std::string>& input_resources) {
  auto inputs = std::make_unique<TfLiteInputs>();
  if (local_compute.has_multiple_input_resources()) {
    if (!input_dir_uri.empty()) {
      return absl::InvalidArgumentError(
          "Both input dir and input resources are provided.");
    }
    // Bind by reference: a plain `auto` would deep-copy the whole proto map.
    const auto& input_resource_tensor_name_map =
        local_compute.multiple_input_resources()
            .input_resource_tensor_name_map();
    for (const auto& [resource_name, resource_value] : input_resources) {
      // A single find() both checks the key and yields the tensor name.
      const auto name_it = input_resource_tensor_name_map.find(resource_name);
      if (name_it == input_resource_tensor_name_map.end()) {
        // If the user provided more input resources than required in the
        // LocalComputeIORouter, we simply continue without throwing an error.
        // In this way, the user could update their scheduling logic separately
        // from their local computation definitions.
        continue;
      }
      (*inputs)[name_it->second] = resource_value;
    }
  } else {
    (*inputs)[local_compute.input_dir_tensor_name()] = input_dir_uri;
  }
  (*inputs)[local_compute.output_dir_tensor_name()] = output_dir_uri;
  return inputs;
}
128*14675a02SAndroid Build Coastguard Worker 
LogComputationOutcome(engine::PlanResult plan_result,PhaseLogger & phase_logger,absl::Time run_plan_start_time,absl::Time reference_time)129*14675a02SAndroid Build Coastguard Worker void LogComputationOutcome(engine::PlanResult plan_result,
130*14675a02SAndroid Build Coastguard Worker                            PhaseLogger& phase_logger,
131*14675a02SAndroid Build Coastguard Worker                            absl::Time run_plan_start_time,
132*14675a02SAndroid Build Coastguard Worker                            absl::Time reference_time) {
133*14675a02SAndroid Build Coastguard Worker   switch (plan_result.outcome) {
134*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kSuccess:
135*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationCompleted(plan_result.example_stats,
136*14675a02SAndroid Build Coastguard Worker                                            NetworkStats(), run_plan_start_time,
137*14675a02SAndroid Build Coastguard Worker                                            reference_time);
138*14675a02SAndroid Build Coastguard Worker       break;
139*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kInterrupted:
140*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationInterrupted(
141*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
142*14675a02SAndroid Build Coastguard Worker           NetworkStats(), run_plan_start_time, reference_time);
143*14675a02SAndroid Build Coastguard Worker       break;
144*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kInvalidArgument:
145*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationInvalidArgument(
146*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
147*14675a02SAndroid Build Coastguard Worker           NetworkStats(), run_plan_start_time);
148*14675a02SAndroid Build Coastguard Worker       break;
149*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kTensorflowError:
150*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationTensorflowError(
151*14675a02SAndroid Build Coastguard Worker           std::move(plan_result.original_status), plan_result.example_stats,
152*14675a02SAndroid Build Coastguard Worker           NetworkStats(), run_plan_start_time, reference_time);
153*14675a02SAndroid Build Coastguard Worker       break;
154*14675a02SAndroid Build Coastguard Worker     case engine::PlanOutcome::kExampleIteratorError:
155*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationExampleIteratorError(
156*14675a02SAndroid Build Coastguard Worker           plan_result.original_status, plan_result.example_stats,
157*14675a02SAndroid Build Coastguard Worker           NetworkStats(), run_plan_start_time);
158*14675a02SAndroid Build Coastguard Worker       break;
159*14675a02SAndroid Build Coastguard Worker   }
160*14675a02SAndroid Build Coastguard Worker }
161*14675a02SAndroid Build Coastguard Worker 
162*14675a02SAndroid Build Coastguard Worker // Creates an ExampleIteratorFactory that routes queries to the
163*14675a02SAndroid Build Coastguard Worker // SimpleTaskEnvironment::CreateExampleIterator() method.
164*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)165*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(
166*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
167*14675a02SAndroid Build Coastguard Worker   return std::make_unique<engine::FunctionalExampleIteratorFactory>(
168*14675a02SAndroid Build Coastguard Worker       /*can_handle_func=*/
169*14675a02SAndroid Build Coastguard Worker       [](const google::internal::federated::plan::ExampleSelector&) {
170*14675a02SAndroid Build Coastguard Worker         // The SimpleTaskEnvironment-based ExampleIteratorFactory should
171*14675a02SAndroid Build Coastguard Worker         // be the catch-all factory that is able to handle all queries
172*14675a02SAndroid Build Coastguard Worker         // that no other ExampleIteratorFactory is able to handle.
173*14675a02SAndroid Build Coastguard Worker         return true;
174*14675a02SAndroid Build Coastguard Worker       },
175*14675a02SAndroid Build Coastguard Worker       /*create_iterator_func=*/
176*14675a02SAndroid Build Coastguard Worker       [task_env, selector_context](
177*14675a02SAndroid Build Coastguard Worker           const google::internal::federated::plan::ExampleSelector&
178*14675a02SAndroid Build Coastguard Worker               example_selector) {
179*14675a02SAndroid Build Coastguard Worker         return task_env->CreateExampleIterator(example_selector,
180*14675a02SAndroid Build Coastguard Worker                                                selector_context);
181*14675a02SAndroid Build Coastguard Worker       },
182*14675a02SAndroid Build Coastguard Worker       /*should_collect_stats=*/true);
183*14675a02SAndroid Build Coastguard Worker }
184*14675a02SAndroid Build Coastguard Worker 
RunPlanWithTensorflowSpec(PhaseLogger & phase_logger,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)185*14675a02SAndroid Build Coastguard Worker absl::Status RunPlanWithTensorflowSpec(
186*14675a02SAndroid Build Coastguard Worker     PhaseLogger& phase_logger,
187*14675a02SAndroid Build Coastguard Worker     std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
188*14675a02SAndroid Build Coastguard Worker     std::function<bool()> should_abort, LogManager* log_manager,
189*14675a02SAndroid Build Coastguard Worker     OpStatsLogger* opstats_logger, const Flags* flags,
190*14675a02SAndroid Build Coastguard Worker     const ClientOnlyPlan& client_plan, const std::string& input_dir_uri,
191*14675a02SAndroid Build Coastguard Worker     const std::string& output_dir_uri,
192*14675a02SAndroid Build Coastguard Worker     const absl::flat_hash_map<std::string, std::string>& input_resources,
193*14675a02SAndroid Build Coastguard Worker     const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
194*14675a02SAndroid Build Coastguard Worker     const absl::Time run_plan_start_time, const absl::Time reference_time) {
195*14675a02SAndroid Build Coastguard Worker   // Check that this is a TensorflowSpec-based plan for local computation.
196*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_tensorflow_spec()) {
197*14675a02SAndroid Build Coastguard Worker     absl::Status error_status =
198*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Plan without TensorflowSpec");
199*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationInvalidArgument(
200*14675a02SAndroid Build Coastguard Worker         error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
201*14675a02SAndroid Build Coastguard Worker     return error_status;
202*14675a02SAndroid Build Coastguard Worker   }
203*14675a02SAndroid Build Coastguard Worker   if (!client_plan.phase().has_local_compute()) {
204*14675a02SAndroid Build Coastguard Worker     absl::Status error_status =
205*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("Invalid TensorflowSpec-based plan");
206*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationInvalidArgument(
207*14675a02SAndroid Build Coastguard Worker         error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
208*14675a02SAndroid Build Coastguard Worker     return error_status;
209*14675a02SAndroid Build Coastguard Worker   }
210*14675a02SAndroid Build Coastguard Worker 
211*14675a02SAndroid Build Coastguard Worker   // Run plan
212*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_names_unused;
213*14675a02SAndroid Build Coastguard Worker 
214*14675a02SAndroid Build Coastguard Worker   if (!client_plan.tflite_graph().empty()) {
215*14675a02SAndroid Build Coastguard Worker     log_manager->LogDiag(
216*14675a02SAndroid Build Coastguard Worker         ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
217*14675a02SAndroid Build Coastguard Worker   }
218*14675a02SAndroid Build Coastguard Worker 
219*14675a02SAndroid Build Coastguard Worker   if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
220*14675a02SAndroid Build Coastguard Worker     auto inputs = ConstructInputsForTFLitePlan(
221*14675a02SAndroid Build Coastguard Worker         client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
222*14675a02SAndroid Build Coastguard Worker         input_resources);
223*14675a02SAndroid Build Coastguard Worker     if (!inputs.ok()) {
224*14675a02SAndroid Build Coastguard Worker       phase_logger.LogComputationInvalidArgument(
225*14675a02SAndroid Build Coastguard Worker           inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
226*14675a02SAndroid Build Coastguard Worker       return inputs.status();
227*14675a02SAndroid Build Coastguard Worker     }
228*14675a02SAndroid Build Coastguard Worker     engine::TfLitePlanEngine plan_engine(example_iterator_factories,
229*14675a02SAndroid Build Coastguard Worker                                          should_abort, log_manager,
230*14675a02SAndroid Build Coastguard Worker                                          opstats_logger, flags, &timing_config);
231*14675a02SAndroid Build Coastguard Worker     engine::PlanResult plan_result = plan_engine.RunPlan(
232*14675a02SAndroid Build Coastguard Worker         client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
233*14675a02SAndroid Build Coastguard Worker         std::move(*inputs), output_names_unused);
234*14675a02SAndroid Build Coastguard Worker     engine::PlanOutcome outcome = plan_result.outcome;
235*14675a02SAndroid Build Coastguard Worker     LogComputationOutcome(std::move(plan_result), phase_logger,
236*14675a02SAndroid Build Coastguard Worker                           run_plan_start_time, reference_time);
237*14675a02SAndroid Build Coastguard Worker     return ConvertPlanOutcomeToStatus(outcome);
238*14675a02SAndroid Build Coastguard Worker   }
239*14675a02SAndroid Build Coastguard Worker 
240*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
241*14675a02SAndroid Build Coastguard Worker   // Construct input tensors based on the values in the LocalComputeIORouter
242*14675a02SAndroid Build Coastguard Worker   // message.
243*14675a02SAndroid Build Coastguard Worker   auto inputs = ConstructInputsForTensorflowSpecPlan(
244*14675a02SAndroid Build Coastguard Worker       client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
245*14675a02SAndroid Build Coastguard Worker       input_resources);
246*14675a02SAndroid Build Coastguard Worker   if (!inputs.ok()) {
247*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationInvalidArgument(
248*14675a02SAndroid Build Coastguard Worker         inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
249*14675a02SAndroid Build Coastguard Worker     return inputs.status();
250*14675a02SAndroid Build Coastguard Worker   }
251*14675a02SAndroid Build Coastguard Worker   engine::SimplePlanEngine plan_engine(
252*14675a02SAndroid Build Coastguard Worker       example_iterator_factories, should_abort, log_manager, opstats_logger,
253*14675a02SAndroid Build Coastguard Worker       &timing_config, flags->support_constant_tf_inputs());
254*14675a02SAndroid Build Coastguard Worker   engine::PlanResult plan_result = plan_engine.RunPlan(
255*14675a02SAndroid Build Coastguard Worker       client_plan.phase().tensorflow_spec(), client_plan.graph(),
256*14675a02SAndroid Build Coastguard Worker       client_plan.tensorflow_config_proto(), std::move(*inputs),
257*14675a02SAndroid Build Coastguard Worker       output_names_unused);
258*14675a02SAndroid Build Coastguard Worker   engine::PlanOutcome outcome = plan_result.outcome;
259*14675a02SAndroid Build Coastguard Worker   LogComputationOutcome(std::move(plan_result), phase_logger,
260*14675a02SAndroid Build Coastguard Worker                         run_plan_start_time, reference_time);
261*14675a02SAndroid Build Coastguard Worker   return ConvertPlanOutcomeToStatus(outcome);
262*14675a02SAndroid Build Coastguard Worker #else
263*14675a02SAndroid Build Coastguard Worker   return absl::InternalError("No plan engine enabled");
264*14675a02SAndroid Build Coastguard Worker #endif
265*14675a02SAndroid Build Coastguard Worker }
266*14675a02SAndroid Build Coastguard Worker }  // anonymous namespace
267*14675a02SAndroid Build Coastguard Worker 
RunLocalComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,LogManager * log_manager,const Flags * flags,const std::string & session_name,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)268*14675a02SAndroid Build Coastguard Worker absl::Status RunLocalComputation(
269*14675a02SAndroid Build Coastguard Worker     SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
270*14675a02SAndroid Build Coastguard Worker     LogManager* log_manager, const Flags* flags,
271*14675a02SAndroid Build Coastguard Worker     const std::string& session_name, const std::string& plan_uri,
272*14675a02SAndroid Build Coastguard Worker     const std::string& input_dir_uri, const std::string& output_dir_uri,
273*14675a02SAndroid Build Coastguard Worker     const absl::flat_hash_map<std::string, std::string>& input_resources) {
274*14675a02SAndroid Build Coastguard Worker   auto opstats_logger = engine::CreateOpStatsLogger(
275*14675a02SAndroid Build Coastguard Worker       env_deps->GetBaseDir(), flags, log_manager, session_name,
276*14675a02SAndroid Build Coastguard Worker       /*population_name=*/"");
277*14675a02SAndroid Build Coastguard Worker   SelectorContext selector_context;
278*14675a02SAndroid Build Coastguard Worker   selector_context.mutable_computation_properties()->set_session_name(
279*14675a02SAndroid Build Coastguard Worker       session_name);
280*14675a02SAndroid Build Coastguard Worker   LocalComputation computation = LocalComputation();
281*14675a02SAndroid Build Coastguard Worker   computation.set_input_dir(input_dir_uri);
282*14675a02SAndroid Build Coastguard Worker   computation.set_output_dir(output_dir_uri);
283*14675a02SAndroid Build Coastguard Worker   computation.mutable_input_resource_map()->insert(input_resources.begin(),
284*14675a02SAndroid Build Coastguard Worker                                                    input_resources.end());
285*14675a02SAndroid Build Coastguard Worker   *selector_context.mutable_computation_properties()->mutable_local_compute() =
286*14675a02SAndroid Build Coastguard Worker       computation;
287*14675a02SAndroid Build Coastguard Worker   PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
288*14675a02SAndroid Build Coastguard Worker                                log_manager, flags);
289*14675a02SAndroid Build Coastguard Worker   return RunLocalComputation(phase_logger, env_deps, log_manager,
290*14675a02SAndroid Build Coastguard Worker                              opstats_logger.get(), flags, plan_uri,
291*14675a02SAndroid Build Coastguard Worker                              input_dir_uri, output_dir_uri, input_resources,
292*14675a02SAndroid Build Coastguard Worker                              selector_context);
293*14675a02SAndroid Build Coastguard Worker }
294*14675a02SAndroid Build Coastguard Worker 
RunLocalComputation(PhaseLogger & phase_logger,SimpleTaskEnvironment * env_deps,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const SelectorContext & selector_context)295*14675a02SAndroid Build Coastguard Worker absl::Status RunLocalComputation(
296*14675a02SAndroid Build Coastguard Worker     PhaseLogger& phase_logger, SimpleTaskEnvironment* env_deps,
297*14675a02SAndroid Build Coastguard Worker     LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
298*14675a02SAndroid Build Coastguard Worker     const std::string& plan_uri, const std::string& input_dir_uri,
299*14675a02SAndroid Build Coastguard Worker     const std::string& output_dir_uri,
300*14675a02SAndroid Build Coastguard Worker     const absl::flat_hash_map<std::string, std::string>& input_resources,
301*14675a02SAndroid Build Coastguard Worker     const SelectorContext& selector_context) {
302*14675a02SAndroid Build Coastguard Worker   absl::Time reference_time = absl::Now();
303*14675a02SAndroid Build Coastguard Worker   absl::Duration polling_period =
304*14675a02SAndroid Build Coastguard Worker       absl::Milliseconds(flags->condition_polling_period_millis());
305*14675a02SAndroid Build Coastguard Worker   std::function<bool()> should_abort = [env_deps, polling_period]() {
306*14675a02SAndroid Build Coastguard Worker     return env_deps->ShouldAbort(absl::Now(), polling_period);
307*14675a02SAndroid Build Coastguard Worker   };
308*14675a02SAndroid Build Coastguard Worker   // Check if the device conditions allow running a local computation.
309*14675a02SAndroid Build Coastguard Worker   if (should_abort()) {
310*14675a02SAndroid Build Coastguard Worker     std::string message =
311*14675a02SAndroid Build Coastguard Worker         "Device conditions not satisfied, aborting local computation";
312*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << message;
313*14675a02SAndroid Build Coastguard Worker     phase_logger.LogTaskNotStarted(message);
314*14675a02SAndroid Build Coastguard Worker     return absl::CancelledError("");
315*14675a02SAndroid Build Coastguard Worker   }
316*14675a02SAndroid Build Coastguard Worker   // Local compute plans can use example iterators from the
317*14675a02SAndroid Build Coastguard Worker   // SimpleTaskEnvironment and those reading the OpStats DB.
318*14675a02SAndroid Build Coastguard Worker   opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
319*14675a02SAndroid Build Coastguard Worker       opstats_logger, log_manager,
320*14675a02SAndroid Build Coastguard Worker       flags->opstats_last_successful_contribution_criteria());
321*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
322*14675a02SAndroid Build Coastguard Worker       CreateSimpleTaskEnvironmentIteratorFactory(env_deps, selector_context);
323*14675a02SAndroid Build Coastguard Worker   std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
324*14675a02SAndroid Build Coastguard Worker       &opstats_example_iterator_factory, env_example_iterator_factory.get()};
325*14675a02SAndroid Build Coastguard Worker 
326*14675a02SAndroid Build Coastguard Worker   fcp::client::InterruptibleRunner::TimingConfig timing_config = {
327*14675a02SAndroid Build Coastguard Worker       .polling_period = polling_period,
328*14675a02SAndroid Build Coastguard Worker       .graceful_shutdown_period = absl::Milliseconds(
329*14675a02SAndroid Build Coastguard Worker           flags->tf_execution_teardown_grace_period_millis()),
330*14675a02SAndroid Build Coastguard Worker       .extended_shutdown_period = absl::Milliseconds(
331*14675a02SAndroid Build Coastguard Worker           flags->tf_execution_teardown_extended_period_millis()),
332*14675a02SAndroid Build Coastguard Worker   };
333*14675a02SAndroid Build Coastguard Worker 
334*14675a02SAndroid Build Coastguard Worker   absl::Time run_plan_start_time = absl::Now();
335*14675a02SAndroid Build Coastguard Worker   phase_logger.LogComputationStarted();
336*14675a02SAndroid Build Coastguard Worker 
337*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::string> plan_str = fcp::ReadFileToString(plan_uri);
338*14675a02SAndroid Build Coastguard Worker   if (!plan_str.ok()) {
339*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationIOError(plan_str.status(), ExampleStats(),
340*14675a02SAndroid Build Coastguard Worker                                        NetworkStats(), run_plan_start_time);
341*14675a02SAndroid Build Coastguard Worker     return plan_str.status();
342*14675a02SAndroid Build Coastguard Worker   }
343*14675a02SAndroid Build Coastguard Worker 
344*14675a02SAndroid Build Coastguard Worker   ClientOnlyPlan plan;
345*14675a02SAndroid Build Coastguard Worker   if (!plan.ParseFromString(*plan_str)) {
346*14675a02SAndroid Build Coastguard Worker     absl::Status error_status =
347*14675a02SAndroid Build Coastguard Worker         absl::InvalidArgumentError("could not parse received plan");
348*14675a02SAndroid Build Coastguard Worker     phase_logger.LogComputationInvalidArgument(
349*14675a02SAndroid Build Coastguard Worker         error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
350*14675a02SAndroid Build Coastguard Worker     return error_status;
351*14675a02SAndroid Build Coastguard Worker   }
352*14675a02SAndroid Build Coastguard Worker 
353*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_names;
354*14675a02SAndroid Build Coastguard Worker   std::vector<tensorflow::Tensor> output_tensors;
355*14675a02SAndroid Build Coastguard Worker   return RunPlanWithTensorflowSpec(
356*14675a02SAndroid Build Coastguard Worker       phase_logger, example_iterator_factories, should_abort, log_manager,
357*14675a02SAndroid Build Coastguard Worker       opstats_logger, flags, plan, input_dir_uri, output_dir_uri,
358*14675a02SAndroid Build Coastguard Worker       input_resources, timing_config, run_plan_start_time, reference_time);
359*14675a02SAndroid Build Coastguard Worker }
360*14675a02SAndroid Build Coastguard Worker 
361*14675a02SAndroid Build Coastguard Worker }  // namespace client
362*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
363