1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2020 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/lc_runner.h"
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker #include <map>
20*14675a02SAndroid Build Coastguard Worker #include <memory>
21*14675a02SAndroid Build Coastguard Worker #include <string>
22*14675a02SAndroid Build Coastguard Worker #include <utility>
23*14675a02SAndroid Build Coastguard Worker #include <vector>
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
28*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
30*14675a02SAndroid Build Coastguard Worker #include "fcp/base/platform.h"
31*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/example_iterator_factory.h"
32*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/plan_engine_helpers.h"
33*14675a02SAndroid Build Coastguard Worker
34*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
35*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/simple_plan_engine.h"
36*14675a02SAndroid Build Coastguard Worker #endif
37*14675a02SAndroid Build Coastguard Worker
38*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tflite_plan_engine.h"
39*14675a02SAndroid Build Coastguard Worker #include "fcp/client/opstats/opstats_example_store.h"
40*14675a02SAndroid Build Coastguard Worker #include "fcp/client/phase_logger_impl.h"
41*14675a02SAndroid Build Coastguard Worker #include "fcp/client/selector_context.pb.h"
42*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/plan.pb.h"
43*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h"
44*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.pb.h"
45*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.pb.h"
46*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/protobuf/struct.pb.h"
47*14675a02SAndroid Build Coastguard Worker
48*14675a02SAndroid Build Coastguard Worker namespace fcp {
49*14675a02SAndroid Build Coastguard Worker namespace client {
50*14675a02SAndroid Build Coastguard Worker
51*14675a02SAndroid Build Coastguard Worker using ::fcp::client::opstats::OpStatsLogger;
52*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::ClientOnlyPlan;
53*14675a02SAndroid Build Coastguard Worker using ::google::internal::federated::plan::LocalComputeIORouter;
54*14675a02SAndroid Build Coastguard Worker
55*14675a02SAndroid Build Coastguard Worker using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
56*14675a02SAndroid Build Coastguard Worker using TfMobileInputs = std::vector<std::pair<std::string, tensorflow::Tensor>>;
57*14675a02SAndroid Build Coastguard Worker
58*14675a02SAndroid Build Coastguard Worker namespace {
59*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
60*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<TfMobileInputs>>
ConstructInputsForTensorflowSpecPlan(const LocalComputeIORouter & local_compute,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)61*14675a02SAndroid Build Coastguard Worker ConstructInputsForTensorflowSpecPlan(
62*14675a02SAndroid Build Coastguard Worker const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
63*14675a02SAndroid Build Coastguard Worker const std::string& output_dir_uri,
64*14675a02SAndroid Build Coastguard Worker const absl::flat_hash_map<std::string, std::string>& input_resources) {
65*14675a02SAndroid Build Coastguard Worker auto inputs = std::make_unique<
66*14675a02SAndroid Build Coastguard Worker std::vector<std::pair<std::string, tensorflow::Tensor>>>();
67*14675a02SAndroid Build Coastguard Worker if (local_compute.has_multiple_input_resources()) {
68*14675a02SAndroid Build Coastguard Worker if (!input_dir_uri.empty()) {
69*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError(
70*14675a02SAndroid Build Coastguard Worker "Both input dir and input resources are provided.");
71*14675a02SAndroid Build Coastguard Worker }
72*14675a02SAndroid Build Coastguard Worker auto input_resource_tensor_name_map =
73*14675a02SAndroid Build Coastguard Worker local_compute.multiple_input_resources()
74*14675a02SAndroid Build Coastguard Worker .input_resource_tensor_name_map();
75*14675a02SAndroid Build Coastguard Worker for (const auto& resource : input_resources) {
76*14675a02SAndroid Build Coastguard Worker tensorflow::Tensor resource_tensor(tensorflow::DT_STRING, {});
77*14675a02SAndroid Build Coastguard Worker resource_tensor.scalar<tensorflow::tstring>()() = resource.second;
78*14675a02SAndroid Build Coastguard Worker if (!input_resource_tensor_name_map.contains(resource.first)) {
79*14675a02SAndroid Build Coastguard Worker return absl::InvalidArgumentError(
80*14675a02SAndroid Build Coastguard Worker absl::StrCat("User provided input resource:", resource.first,
81*14675a02SAndroid Build Coastguard Worker " is missing in LocalComputeIORouter."));
82*14675a02SAndroid Build Coastguard Worker }
83*14675a02SAndroid Build Coastguard Worker std::string tensor_name = input_resource_tensor_name_map[resource.first];
84*14675a02SAndroid Build Coastguard Worker inputs->push_back({tensor_name, resource_tensor});
85*14675a02SAndroid Build Coastguard Worker }
86*14675a02SAndroid Build Coastguard Worker } else {
87*14675a02SAndroid Build Coastguard Worker tensorflow::Tensor input_dirpath(tensorflow::DT_STRING, {});
88*14675a02SAndroid Build Coastguard Worker input_dirpath.scalar<tensorflow::tstring>()() = input_dir_uri;
89*14675a02SAndroid Build Coastguard Worker inputs->push_back({local_compute.input_dir_tensor_name(), input_dirpath});
90*14675a02SAndroid Build Coastguard Worker }
91*14675a02SAndroid Build Coastguard Worker tensorflow::Tensor output_dirpath(tensorflow::DT_STRING, {});
92*14675a02SAndroid Build Coastguard Worker output_dirpath.scalar<tensorflow::tstring>()() = output_dir_uri;
93*14675a02SAndroid Build Coastguard Worker inputs->push_back({local_compute.output_dir_tensor_name(), output_dirpath});
94*14675a02SAndroid Build Coastguard Worker return inputs;
95*14675a02SAndroid Build Coastguard Worker }
96*14675a02SAndroid Build Coastguard Worker #endif
97*14675a02SAndroid Build Coastguard Worker
// Builds the input map (tensor name -> string value) for a TFLite
// local-compute plan.
//
//  - If `local_compute` declares multiple input resources, each entry of
//    `input_resources` is stored under the tensor name the router maps it to.
//    Resources the router does not list are skipped (see comment below).
//  - Otherwise `input_dir_uri` is stored under the router's input dir tensor.
//  - `output_dir_uri` is always stored under the router's output dir tensor.
//
// Returns InvalidArgumentError if the router expects multiple input resources
// but a non-empty `input_dir_uri` was also given.
absl::StatusOr<std::unique_ptr<TfLiteInputs>> ConstructInputsForTFLitePlan(
    const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
    const std::string& output_dir_uri,
    const absl::flat_hash_map<std::string, std::string>& input_resources) {
  auto inputs = std::make_unique<TfLiteInputs>();
  if (local_compute.has_multiple_input_resources()) {
    if (!input_dir_uri.empty()) {
      return absl::InvalidArgumentError(
          "Both input dir and input resources are provided.");
    }
    // Bind by reference: the router's map is only read, so copying it (as a
    // by-value `auto` would) is wasted work.
    const auto& input_resource_tensor_name_map =
        local_compute.multiple_input_resources()
            .input_resource_tensor_name_map();
    for (const auto& [resource_name, resource_value] : input_resources) {
      // Single lookup instead of contains() followed by operator[].
      auto it = input_resource_tensor_name_map.find(resource_name);
      if (it == input_resource_tensor_name_map.end()) {
        // If the user provided more input resources than required in the
        // LocalComputeIORouter, we simply continue without throwing an error.
        // In this way, the user could update their scheduling logic separately
        // from their local computation definitions.
        continue;
      }
      (*inputs)[it->second] = resource_value;
    }
  } else {
    (*inputs)[local_compute.input_dir_tensor_name()] = input_dir_uri;
  }
  (*inputs)[local_compute.output_dir_tensor_name()] = output_dir_uri;
  return inputs;
}
128*14675a02SAndroid Build Coastguard Worker
LogComputationOutcome(engine::PlanResult plan_result,PhaseLogger & phase_logger,absl::Time run_plan_start_time,absl::Time reference_time)129*14675a02SAndroid Build Coastguard Worker void LogComputationOutcome(engine::PlanResult plan_result,
130*14675a02SAndroid Build Coastguard Worker PhaseLogger& phase_logger,
131*14675a02SAndroid Build Coastguard Worker absl::Time run_plan_start_time,
132*14675a02SAndroid Build Coastguard Worker absl::Time reference_time) {
133*14675a02SAndroid Build Coastguard Worker switch (plan_result.outcome) {
134*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kSuccess:
135*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationCompleted(plan_result.example_stats,
136*14675a02SAndroid Build Coastguard Worker NetworkStats(), run_plan_start_time,
137*14675a02SAndroid Build Coastguard Worker reference_time);
138*14675a02SAndroid Build Coastguard Worker break;
139*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kInterrupted:
140*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInterrupted(
141*14675a02SAndroid Build Coastguard Worker plan_result.original_status, plan_result.example_stats,
142*14675a02SAndroid Build Coastguard Worker NetworkStats(), run_plan_start_time, reference_time);
143*14675a02SAndroid Build Coastguard Worker break;
144*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kInvalidArgument:
145*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInvalidArgument(
146*14675a02SAndroid Build Coastguard Worker plan_result.original_status, plan_result.example_stats,
147*14675a02SAndroid Build Coastguard Worker NetworkStats(), run_plan_start_time);
148*14675a02SAndroid Build Coastguard Worker break;
149*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kTensorflowError:
150*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationTensorflowError(
151*14675a02SAndroid Build Coastguard Worker std::move(plan_result.original_status), plan_result.example_stats,
152*14675a02SAndroid Build Coastguard Worker NetworkStats(), run_plan_start_time, reference_time);
153*14675a02SAndroid Build Coastguard Worker break;
154*14675a02SAndroid Build Coastguard Worker case engine::PlanOutcome::kExampleIteratorError:
155*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationExampleIteratorError(
156*14675a02SAndroid Build Coastguard Worker plan_result.original_status, plan_result.example_stats,
157*14675a02SAndroid Build Coastguard Worker NetworkStats(), run_plan_start_time);
158*14675a02SAndroid Build Coastguard Worker break;
159*14675a02SAndroid Build Coastguard Worker }
160*14675a02SAndroid Build Coastguard Worker }
161*14675a02SAndroid Build Coastguard Worker
162*14675a02SAndroid Build Coastguard Worker // Creates an ExampleIteratorFactory that routes queries to the
163*14675a02SAndroid Build Coastguard Worker // SimpleTaskEnvironment::CreateExampleIterator() method.
164*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)165*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(
166*14675a02SAndroid Build Coastguard Worker SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
167*14675a02SAndroid Build Coastguard Worker return std::make_unique<engine::FunctionalExampleIteratorFactory>(
168*14675a02SAndroid Build Coastguard Worker /*can_handle_func=*/
169*14675a02SAndroid Build Coastguard Worker [](const google::internal::federated::plan::ExampleSelector&) {
170*14675a02SAndroid Build Coastguard Worker // The SimpleTaskEnvironment-based ExampleIteratorFactory should
171*14675a02SAndroid Build Coastguard Worker // be the catch-all factory that is able to handle all queries
172*14675a02SAndroid Build Coastguard Worker // that no other ExampleIteratorFactory is able to handle.
173*14675a02SAndroid Build Coastguard Worker return true;
174*14675a02SAndroid Build Coastguard Worker },
175*14675a02SAndroid Build Coastguard Worker /*create_iterator_func=*/
176*14675a02SAndroid Build Coastguard Worker [task_env, selector_context](
177*14675a02SAndroid Build Coastguard Worker const google::internal::federated::plan::ExampleSelector&
178*14675a02SAndroid Build Coastguard Worker example_selector) {
179*14675a02SAndroid Build Coastguard Worker return task_env->CreateExampleIterator(example_selector,
180*14675a02SAndroid Build Coastguard Worker selector_context);
181*14675a02SAndroid Build Coastguard Worker },
182*14675a02SAndroid Build Coastguard Worker /*should_collect_stats=*/true);
183*14675a02SAndroid Build Coastguard Worker }
184*14675a02SAndroid Build Coastguard Worker
RunPlanWithTensorflowSpec(PhaseLogger & phase_logger,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)185*14675a02SAndroid Build Coastguard Worker absl::Status RunPlanWithTensorflowSpec(
186*14675a02SAndroid Build Coastguard Worker PhaseLogger& phase_logger,
187*14675a02SAndroid Build Coastguard Worker std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
188*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort, LogManager* log_manager,
189*14675a02SAndroid Build Coastguard Worker OpStatsLogger* opstats_logger, const Flags* flags,
190*14675a02SAndroid Build Coastguard Worker const ClientOnlyPlan& client_plan, const std::string& input_dir_uri,
191*14675a02SAndroid Build Coastguard Worker const std::string& output_dir_uri,
192*14675a02SAndroid Build Coastguard Worker const absl::flat_hash_map<std::string, std::string>& input_resources,
193*14675a02SAndroid Build Coastguard Worker const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
194*14675a02SAndroid Build Coastguard Worker const absl::Time run_plan_start_time, const absl::Time reference_time) {
195*14675a02SAndroid Build Coastguard Worker // Check that this is a TensorflowSpec-based plan for local computation.
196*14675a02SAndroid Build Coastguard Worker if (!client_plan.phase().has_tensorflow_spec()) {
197*14675a02SAndroid Build Coastguard Worker absl::Status error_status =
198*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Plan without TensorflowSpec");
199*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInvalidArgument(
200*14675a02SAndroid Build Coastguard Worker error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
201*14675a02SAndroid Build Coastguard Worker return error_status;
202*14675a02SAndroid Build Coastguard Worker }
203*14675a02SAndroid Build Coastguard Worker if (!client_plan.phase().has_local_compute()) {
204*14675a02SAndroid Build Coastguard Worker absl::Status error_status =
205*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("Invalid TensorflowSpec-based plan");
206*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInvalidArgument(
207*14675a02SAndroid Build Coastguard Worker error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
208*14675a02SAndroid Build Coastguard Worker return error_status;
209*14675a02SAndroid Build Coastguard Worker }
210*14675a02SAndroid Build Coastguard Worker
211*14675a02SAndroid Build Coastguard Worker // Run plan
212*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_names_unused;
213*14675a02SAndroid Build Coastguard Worker
214*14675a02SAndroid Build Coastguard Worker if (!client_plan.tflite_graph().empty()) {
215*14675a02SAndroid Build Coastguard Worker log_manager->LogDiag(
216*14675a02SAndroid Build Coastguard Worker ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
217*14675a02SAndroid Build Coastguard Worker }
218*14675a02SAndroid Build Coastguard Worker
219*14675a02SAndroid Build Coastguard Worker if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
220*14675a02SAndroid Build Coastguard Worker auto inputs = ConstructInputsForTFLitePlan(
221*14675a02SAndroid Build Coastguard Worker client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
222*14675a02SAndroid Build Coastguard Worker input_resources);
223*14675a02SAndroid Build Coastguard Worker if (!inputs.ok()) {
224*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInvalidArgument(
225*14675a02SAndroid Build Coastguard Worker inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
226*14675a02SAndroid Build Coastguard Worker return inputs.status();
227*14675a02SAndroid Build Coastguard Worker }
228*14675a02SAndroid Build Coastguard Worker engine::TfLitePlanEngine plan_engine(example_iterator_factories,
229*14675a02SAndroid Build Coastguard Worker should_abort, log_manager,
230*14675a02SAndroid Build Coastguard Worker opstats_logger, flags, &timing_config);
231*14675a02SAndroid Build Coastguard Worker engine::PlanResult plan_result = plan_engine.RunPlan(
232*14675a02SAndroid Build Coastguard Worker client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
233*14675a02SAndroid Build Coastguard Worker std::move(*inputs), output_names_unused);
234*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome outcome = plan_result.outcome;
235*14675a02SAndroid Build Coastguard Worker LogComputationOutcome(std::move(plan_result), phase_logger,
236*14675a02SAndroid Build Coastguard Worker run_plan_start_time, reference_time);
237*14675a02SAndroid Build Coastguard Worker return ConvertPlanOutcomeToStatus(outcome);
238*14675a02SAndroid Build Coastguard Worker }
239*14675a02SAndroid Build Coastguard Worker
240*14675a02SAndroid Build Coastguard Worker #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
241*14675a02SAndroid Build Coastguard Worker // Construct input tensors based on the values in the LocalComputeIORouter
242*14675a02SAndroid Build Coastguard Worker // message.
243*14675a02SAndroid Build Coastguard Worker auto inputs = ConstructInputsForTensorflowSpecPlan(
244*14675a02SAndroid Build Coastguard Worker client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
245*14675a02SAndroid Build Coastguard Worker input_resources);
246*14675a02SAndroid Build Coastguard Worker if (!inputs.ok()) {
247*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInvalidArgument(
248*14675a02SAndroid Build Coastguard Worker inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
249*14675a02SAndroid Build Coastguard Worker return inputs.status();
250*14675a02SAndroid Build Coastguard Worker }
251*14675a02SAndroid Build Coastguard Worker engine::SimplePlanEngine plan_engine(
252*14675a02SAndroid Build Coastguard Worker example_iterator_factories, should_abort, log_manager, opstats_logger,
253*14675a02SAndroid Build Coastguard Worker &timing_config, flags->support_constant_tf_inputs());
254*14675a02SAndroid Build Coastguard Worker engine::PlanResult plan_result = plan_engine.RunPlan(
255*14675a02SAndroid Build Coastguard Worker client_plan.phase().tensorflow_spec(), client_plan.graph(),
256*14675a02SAndroid Build Coastguard Worker client_plan.tensorflow_config_proto(), std::move(*inputs),
257*14675a02SAndroid Build Coastguard Worker output_names_unused);
258*14675a02SAndroid Build Coastguard Worker engine::PlanOutcome outcome = plan_result.outcome;
259*14675a02SAndroid Build Coastguard Worker LogComputationOutcome(std::move(plan_result), phase_logger,
260*14675a02SAndroid Build Coastguard Worker run_plan_start_time, reference_time);
261*14675a02SAndroid Build Coastguard Worker return ConvertPlanOutcomeToStatus(outcome);
262*14675a02SAndroid Build Coastguard Worker #else
263*14675a02SAndroid Build Coastguard Worker return absl::InternalError("No plan engine enabled");
264*14675a02SAndroid Build Coastguard Worker #endif
265*14675a02SAndroid Build Coastguard Worker }
266*14675a02SAndroid Build Coastguard Worker } // anonymous namespace
267*14675a02SAndroid Build Coastguard Worker
RunLocalComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,LogManager * log_manager,const Flags * flags,const std::string & session_name,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)268*14675a02SAndroid Build Coastguard Worker absl::Status RunLocalComputation(
269*14675a02SAndroid Build Coastguard Worker SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
270*14675a02SAndroid Build Coastguard Worker LogManager* log_manager, const Flags* flags,
271*14675a02SAndroid Build Coastguard Worker const std::string& session_name, const std::string& plan_uri,
272*14675a02SAndroid Build Coastguard Worker const std::string& input_dir_uri, const std::string& output_dir_uri,
273*14675a02SAndroid Build Coastguard Worker const absl::flat_hash_map<std::string, std::string>& input_resources) {
274*14675a02SAndroid Build Coastguard Worker auto opstats_logger = engine::CreateOpStatsLogger(
275*14675a02SAndroid Build Coastguard Worker env_deps->GetBaseDir(), flags, log_manager, session_name,
276*14675a02SAndroid Build Coastguard Worker /*population_name=*/"");
277*14675a02SAndroid Build Coastguard Worker SelectorContext selector_context;
278*14675a02SAndroid Build Coastguard Worker selector_context.mutable_computation_properties()->set_session_name(
279*14675a02SAndroid Build Coastguard Worker session_name);
280*14675a02SAndroid Build Coastguard Worker LocalComputation computation = LocalComputation();
281*14675a02SAndroid Build Coastguard Worker computation.set_input_dir(input_dir_uri);
282*14675a02SAndroid Build Coastguard Worker computation.set_output_dir(output_dir_uri);
283*14675a02SAndroid Build Coastguard Worker computation.mutable_input_resource_map()->insert(input_resources.begin(),
284*14675a02SAndroid Build Coastguard Worker input_resources.end());
285*14675a02SAndroid Build Coastguard Worker *selector_context.mutable_computation_properties()->mutable_local_compute() =
286*14675a02SAndroid Build Coastguard Worker computation;
287*14675a02SAndroid Build Coastguard Worker PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
288*14675a02SAndroid Build Coastguard Worker log_manager, flags);
289*14675a02SAndroid Build Coastguard Worker return RunLocalComputation(phase_logger, env_deps, log_manager,
290*14675a02SAndroid Build Coastguard Worker opstats_logger.get(), flags, plan_uri,
291*14675a02SAndroid Build Coastguard Worker input_dir_uri, output_dir_uri, input_resources,
292*14675a02SAndroid Build Coastguard Worker selector_context);
293*14675a02SAndroid Build Coastguard Worker }
294*14675a02SAndroid Build Coastguard Worker
RunLocalComputation(PhaseLogger & phase_logger,SimpleTaskEnvironment * env_deps,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const SelectorContext & selector_context)295*14675a02SAndroid Build Coastguard Worker absl::Status RunLocalComputation(
296*14675a02SAndroid Build Coastguard Worker PhaseLogger& phase_logger, SimpleTaskEnvironment* env_deps,
297*14675a02SAndroid Build Coastguard Worker LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
298*14675a02SAndroid Build Coastguard Worker const std::string& plan_uri, const std::string& input_dir_uri,
299*14675a02SAndroid Build Coastguard Worker const std::string& output_dir_uri,
300*14675a02SAndroid Build Coastguard Worker const absl::flat_hash_map<std::string, std::string>& input_resources,
301*14675a02SAndroid Build Coastguard Worker const SelectorContext& selector_context) {
302*14675a02SAndroid Build Coastguard Worker absl::Time reference_time = absl::Now();
303*14675a02SAndroid Build Coastguard Worker absl::Duration polling_period =
304*14675a02SAndroid Build Coastguard Worker absl::Milliseconds(flags->condition_polling_period_millis());
305*14675a02SAndroid Build Coastguard Worker std::function<bool()> should_abort = [env_deps, polling_period]() {
306*14675a02SAndroid Build Coastguard Worker return env_deps->ShouldAbort(absl::Now(), polling_period);
307*14675a02SAndroid Build Coastguard Worker };
308*14675a02SAndroid Build Coastguard Worker // Check if the device conditions allow running a local computation.
309*14675a02SAndroid Build Coastguard Worker if (should_abort()) {
310*14675a02SAndroid Build Coastguard Worker std::string message =
311*14675a02SAndroid Build Coastguard Worker "Device conditions not satisfied, aborting local computation";
312*14675a02SAndroid Build Coastguard Worker FCP_LOG(INFO) << message;
313*14675a02SAndroid Build Coastguard Worker phase_logger.LogTaskNotStarted(message);
314*14675a02SAndroid Build Coastguard Worker return absl::CancelledError("");
315*14675a02SAndroid Build Coastguard Worker }
316*14675a02SAndroid Build Coastguard Worker // Local compute plans can use example iterators from the
317*14675a02SAndroid Build Coastguard Worker // SimpleTaskEnvironment and those reading the OpStats DB.
318*14675a02SAndroid Build Coastguard Worker opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
319*14675a02SAndroid Build Coastguard Worker opstats_logger, log_manager,
320*14675a02SAndroid Build Coastguard Worker flags->opstats_last_successful_contribution_criteria());
321*14675a02SAndroid Build Coastguard Worker std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
322*14675a02SAndroid Build Coastguard Worker CreateSimpleTaskEnvironmentIteratorFactory(env_deps, selector_context);
323*14675a02SAndroid Build Coastguard Worker std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
324*14675a02SAndroid Build Coastguard Worker &opstats_example_iterator_factory, env_example_iterator_factory.get()};
325*14675a02SAndroid Build Coastguard Worker
326*14675a02SAndroid Build Coastguard Worker fcp::client::InterruptibleRunner::TimingConfig timing_config = {
327*14675a02SAndroid Build Coastguard Worker .polling_period = polling_period,
328*14675a02SAndroid Build Coastguard Worker .graceful_shutdown_period = absl::Milliseconds(
329*14675a02SAndroid Build Coastguard Worker flags->tf_execution_teardown_grace_period_millis()),
330*14675a02SAndroid Build Coastguard Worker .extended_shutdown_period = absl::Milliseconds(
331*14675a02SAndroid Build Coastguard Worker flags->tf_execution_teardown_extended_period_millis()),
332*14675a02SAndroid Build Coastguard Worker };
333*14675a02SAndroid Build Coastguard Worker
334*14675a02SAndroid Build Coastguard Worker absl::Time run_plan_start_time = absl::Now();
335*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationStarted();
336*14675a02SAndroid Build Coastguard Worker
337*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::string> plan_str = fcp::ReadFileToString(plan_uri);
338*14675a02SAndroid Build Coastguard Worker if (!plan_str.ok()) {
339*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationIOError(plan_str.status(), ExampleStats(),
340*14675a02SAndroid Build Coastguard Worker NetworkStats(), run_plan_start_time);
341*14675a02SAndroid Build Coastguard Worker return plan_str.status();
342*14675a02SAndroid Build Coastguard Worker }
343*14675a02SAndroid Build Coastguard Worker
344*14675a02SAndroid Build Coastguard Worker ClientOnlyPlan plan;
345*14675a02SAndroid Build Coastguard Worker if (!plan.ParseFromString(*plan_str)) {
346*14675a02SAndroid Build Coastguard Worker absl::Status error_status =
347*14675a02SAndroid Build Coastguard Worker absl::InvalidArgumentError("could not parse received plan");
348*14675a02SAndroid Build Coastguard Worker phase_logger.LogComputationInvalidArgument(
349*14675a02SAndroid Build Coastguard Worker error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
350*14675a02SAndroid Build Coastguard Worker return error_status;
351*14675a02SAndroid Build Coastguard Worker }
352*14675a02SAndroid Build Coastguard Worker
353*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_names;
354*14675a02SAndroid Build Coastguard Worker std::vector<tensorflow::Tensor> output_tensors;
355*14675a02SAndroid Build Coastguard Worker return RunPlanWithTensorflowSpec(
356*14675a02SAndroid Build Coastguard Worker phase_logger, example_iterator_factories, should_abort, log_manager,
357*14675a02SAndroid Build Coastguard Worker opstats_logger, flags, plan, input_dir_uri, output_dir_uri,
358*14675a02SAndroid Build Coastguard Worker input_resources, timing_config, run_plan_start_time, reference_time);
359*14675a02SAndroid Build Coastguard Worker }
360*14675a02SAndroid Build Coastguard Worker
361*14675a02SAndroid Build Coastguard Worker } // namespace client
362*14675a02SAndroid Build Coastguard Worker } // namespace fcp
363