1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/lc_runner.h"
17
18 #include <functional>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/time/time.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/base/platform.h"
31 #include "fcp/client/engine/example_iterator_factory.h"
32 #include "fcp/client/engine/plan_engine_helpers.h"
33
34 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
35 #include "fcp/client/engine/simple_plan_engine.h"
36 #endif
37
38 #include "fcp/client/engine/tflite_plan_engine.h"
39 #include "fcp/client/opstats/opstats_example_store.h"
40 #include "fcp/client/phase_logger_impl.h"
41 #include "fcp/client/selector_context.pb.h"
42 #include "fcp/protos/plan.pb.h"
43 #include "tensorflow/core/framework/tensor.h"
44 #include "tensorflow/core/framework/tensor.pb.h"
45 #include "tensorflow/core/framework/tensor_shape.pb.h"
46 #include "tensorflow/core/protobuf/struct.pb.h"
47
48 namespace fcp {
49 namespace client {
50
51 using ::fcp::client::opstats::OpStatsLogger;
52 using ::google::internal::federated::plan::ClientOnlyPlan;
53 using ::google::internal::federated::plan::LocalComputeIORouter;
54
55 using TfLiteInputs = absl::flat_hash_map<std::string, std::string>;
56 using TfMobileInputs = std::vector<std::pair<std::string, tensorflow::Tensor>>;
57
58 namespace {
59 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
60 absl::StatusOr<std::unique_ptr<TfMobileInputs>>
ConstructInputsForTensorflowSpecPlan(const LocalComputeIORouter & local_compute,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)61 ConstructInputsForTensorflowSpecPlan(
62 const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
63 const std::string& output_dir_uri,
64 const absl::flat_hash_map<std::string, std::string>& input_resources) {
65 auto inputs = std::make_unique<
66 std::vector<std::pair<std::string, tensorflow::Tensor>>>();
67 if (local_compute.has_multiple_input_resources()) {
68 if (!input_dir_uri.empty()) {
69 return absl::InvalidArgumentError(
70 "Both input dir and input resources are provided.");
71 }
72 auto input_resource_tensor_name_map =
73 local_compute.multiple_input_resources()
74 .input_resource_tensor_name_map();
75 for (const auto& resource : input_resources) {
76 tensorflow::Tensor resource_tensor(tensorflow::DT_STRING, {});
77 resource_tensor.scalar<tensorflow::tstring>()() = resource.second;
78 if (!input_resource_tensor_name_map.contains(resource.first)) {
79 return absl::InvalidArgumentError(
80 absl::StrCat("User provided input resource:", resource.first,
81 " is missing in LocalComputeIORouter."));
82 }
83 std::string tensor_name = input_resource_tensor_name_map[resource.first];
84 inputs->push_back({tensor_name, resource_tensor});
85 }
86 } else {
87 tensorflow::Tensor input_dirpath(tensorflow::DT_STRING, {});
88 input_dirpath.scalar<tensorflow::tstring>()() = input_dir_uri;
89 inputs->push_back({local_compute.input_dir_tensor_name(), input_dirpath});
90 }
91 tensorflow::Tensor output_dirpath(tensorflow::DT_STRING, {});
92 output_dirpath.scalar<tensorflow::tstring>()() = output_dir_uri;
93 inputs->push_back({local_compute.output_dir_tensor_name(), output_dirpath});
94 return inputs;
95 }
96 #endif
97
// Builds the tensor-name -> string-value inputs that are fed to a TFLite
// local compute plan.
//
// When `local_compute` declares multiple input resources, each entry of
// `input_resources` that has a tensor name in the router's map is stored
// under that tensor name; `input_dir_uri` must be empty in that mode.
// Otherwise `input_dir_uri` is stored under the router's input dir tensor
// name. In both modes `output_dir_uri` is stored under the router's output dir
// tensor name.
//
// Returns InvalidArgumentError if an input dir is supplied together with
// multiple input resources.
absl::StatusOr<std::unique_ptr<TfLiteInputs>> ConstructInputsForTFLitePlan(
    const LocalComputeIORouter& local_compute, const std::string& input_dir_uri,
    const std::string& output_dir_uri,
    const absl::flat_hash_map<std::string, std::string>& input_resources) {
  auto inputs = std::make_unique<TfLiteInputs>();
  if (local_compute.has_multiple_input_resources()) {
    if (!input_dir_uri.empty()) {
      return absl::InvalidArgumentError(
          "Both input dir and input resources are provided.");
    }
    // Bind by reference: the map is owned by `local_compute`, so there is no
    // need to copy it for a read-only lookup.
    const auto& tensor_names = local_compute.multiple_input_resources()
                                   .input_resource_tensor_name_map();
    for (const auto& [resource_name, resource_value] : input_resources) {
      const auto name_it = tensor_names.find(resource_name);
      if (name_it == tensor_names.end()) {
        // If the user provided more input resources than required in the
        // LocalComputeIORouter, we simply continue without throwing an error.
        // In this way, the user could update their scheduling logic separately
        // from their local computation definitions.
        continue;
      }
      (*inputs)[name_it->second] = resource_value;
    }
  } else {
    (*inputs)[local_compute.input_dir_tensor_name()] = input_dir_uri;
  }
  (*inputs)[local_compute.output_dir_tensor_name()] = output_dir_uri;
  return inputs;
}
128
LogComputationOutcome(engine::PlanResult plan_result,PhaseLogger & phase_logger,absl::Time run_plan_start_time,absl::Time reference_time)129 void LogComputationOutcome(engine::PlanResult plan_result,
130 PhaseLogger& phase_logger,
131 absl::Time run_plan_start_time,
132 absl::Time reference_time) {
133 switch (plan_result.outcome) {
134 case engine::PlanOutcome::kSuccess:
135 phase_logger.LogComputationCompleted(plan_result.example_stats,
136 NetworkStats(), run_plan_start_time,
137 reference_time);
138 break;
139 case engine::PlanOutcome::kInterrupted:
140 phase_logger.LogComputationInterrupted(
141 plan_result.original_status, plan_result.example_stats,
142 NetworkStats(), run_plan_start_time, reference_time);
143 break;
144 case engine::PlanOutcome::kInvalidArgument:
145 phase_logger.LogComputationInvalidArgument(
146 plan_result.original_status, plan_result.example_stats,
147 NetworkStats(), run_plan_start_time);
148 break;
149 case engine::PlanOutcome::kTensorflowError:
150 phase_logger.LogComputationTensorflowError(
151 std::move(plan_result.original_status), plan_result.example_stats,
152 NetworkStats(), run_plan_start_time, reference_time);
153 break;
154 case engine::PlanOutcome::kExampleIteratorError:
155 phase_logger.LogComputationExampleIteratorError(
156 plan_result.original_status, plan_result.example_stats,
157 NetworkStats(), run_plan_start_time);
158 break;
159 }
160 }
161
162 // Creates an ExampleIteratorFactory that routes queries to the
163 // SimpleTaskEnvironment::CreateExampleIterator() method.
164 std::unique_ptr<engine::ExampleIteratorFactory>
CreateSimpleTaskEnvironmentIteratorFactory(SimpleTaskEnvironment * task_env,const SelectorContext & selector_context)165 CreateSimpleTaskEnvironmentIteratorFactory(
166 SimpleTaskEnvironment* task_env, const SelectorContext& selector_context) {
167 return std::make_unique<engine::FunctionalExampleIteratorFactory>(
168 /*can_handle_func=*/
169 [](const google::internal::federated::plan::ExampleSelector&) {
170 // The SimpleTaskEnvironment-based ExampleIteratorFactory should
171 // be the catch-all factory that is able to handle all queries
172 // that no other ExampleIteratorFactory is able to handle.
173 return true;
174 },
175 /*create_iterator_func=*/
176 [task_env, selector_context](
177 const google::internal::federated::plan::ExampleSelector&
178 example_selector) {
179 return task_env->CreateExampleIterator(example_selector,
180 selector_context);
181 },
182 /*should_collect_stats=*/true);
183 }
184
RunPlanWithTensorflowSpec(PhaseLogger & phase_logger,std::vector<engine::ExampleIteratorFactory * > example_iterator_factories,std::function<bool ()> should_abort,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const ClientOnlyPlan & client_plan,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const fcp::client::InterruptibleRunner::TimingConfig & timing_config,const absl::Time run_plan_start_time,const absl::Time reference_time)185 absl::Status RunPlanWithTensorflowSpec(
186 PhaseLogger& phase_logger,
187 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories,
188 std::function<bool()> should_abort, LogManager* log_manager,
189 OpStatsLogger* opstats_logger, const Flags* flags,
190 const ClientOnlyPlan& client_plan, const std::string& input_dir_uri,
191 const std::string& output_dir_uri,
192 const absl::flat_hash_map<std::string, std::string>& input_resources,
193 const fcp::client::InterruptibleRunner::TimingConfig& timing_config,
194 const absl::Time run_plan_start_time, const absl::Time reference_time) {
195 // Check that this is a TensorflowSpec-based plan for local computation.
196 if (!client_plan.phase().has_tensorflow_spec()) {
197 absl::Status error_status =
198 absl::InvalidArgumentError("Plan without TensorflowSpec");
199 phase_logger.LogComputationInvalidArgument(
200 error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
201 return error_status;
202 }
203 if (!client_plan.phase().has_local_compute()) {
204 absl::Status error_status =
205 absl::InvalidArgumentError("Invalid TensorflowSpec-based plan");
206 phase_logger.LogComputationInvalidArgument(
207 error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
208 return error_status;
209 }
210
211 // Run plan
212 std::vector<std::string> output_names_unused;
213
214 if (!client_plan.tflite_graph().empty()) {
215 log_manager->LogDiag(
216 ProdDiagCode::BACKGROUND_TRAINING_TFLITE_MODEL_INCLUDED);
217 }
218
219 if (flags->use_tflite_training() && !client_plan.tflite_graph().empty()) {
220 auto inputs = ConstructInputsForTFLitePlan(
221 client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
222 input_resources);
223 if (!inputs.ok()) {
224 phase_logger.LogComputationInvalidArgument(
225 inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
226 return inputs.status();
227 }
228 engine::TfLitePlanEngine plan_engine(example_iterator_factories,
229 should_abort, log_manager,
230 opstats_logger, flags, &timing_config);
231 engine::PlanResult plan_result = plan_engine.RunPlan(
232 client_plan.phase().tensorflow_spec(), client_plan.tflite_graph(),
233 std::move(*inputs), output_names_unused);
234 engine::PlanOutcome outcome = plan_result.outcome;
235 LogComputationOutcome(std::move(plan_result), phase_logger,
236 run_plan_start_time, reference_time);
237 return ConvertPlanOutcomeToStatus(outcome);
238 }
239
240 #ifdef FCP_CLIENT_SUPPORT_TFMOBILE
241 // Construct input tensors based on the values in the LocalComputeIORouter
242 // message.
243 auto inputs = ConstructInputsForTensorflowSpecPlan(
244 client_plan.phase().local_compute(), input_dir_uri, output_dir_uri,
245 input_resources);
246 if (!inputs.ok()) {
247 phase_logger.LogComputationInvalidArgument(
248 inputs.status(), ExampleStats(), NetworkStats(), run_plan_start_time);
249 return inputs.status();
250 }
251 engine::SimplePlanEngine plan_engine(
252 example_iterator_factories, should_abort, log_manager, opstats_logger,
253 &timing_config, flags->support_constant_tf_inputs());
254 engine::PlanResult plan_result = plan_engine.RunPlan(
255 client_plan.phase().tensorflow_spec(), client_plan.graph(),
256 client_plan.tensorflow_config_proto(), std::move(*inputs),
257 output_names_unused);
258 engine::PlanOutcome outcome = plan_result.outcome;
259 LogComputationOutcome(std::move(plan_result), phase_logger,
260 run_plan_start_time, reference_time);
261 return ConvertPlanOutcomeToStatus(outcome);
262 #else
263 return absl::InternalError("No plan engine enabled");
264 #endif
265 }
266 } // anonymous namespace
267
RunLocalComputation(SimpleTaskEnvironment * env_deps,EventPublisher * event_publisher,LogManager * log_manager,const Flags * flags,const std::string & session_name,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources)268 absl::Status RunLocalComputation(
269 SimpleTaskEnvironment* env_deps, EventPublisher* event_publisher,
270 LogManager* log_manager, const Flags* flags,
271 const std::string& session_name, const std::string& plan_uri,
272 const std::string& input_dir_uri, const std::string& output_dir_uri,
273 const absl::flat_hash_map<std::string, std::string>& input_resources) {
274 auto opstats_logger = engine::CreateOpStatsLogger(
275 env_deps->GetBaseDir(), flags, log_manager, session_name,
276 /*population_name=*/"");
277 SelectorContext selector_context;
278 selector_context.mutable_computation_properties()->set_session_name(
279 session_name);
280 LocalComputation computation = LocalComputation();
281 computation.set_input_dir(input_dir_uri);
282 computation.set_output_dir(output_dir_uri);
283 computation.mutable_input_resource_map()->insert(input_resources.begin(),
284 input_resources.end());
285 *selector_context.mutable_computation_properties()->mutable_local_compute() =
286 computation;
287 PhaseLoggerImpl phase_logger(event_publisher, opstats_logger.get(),
288 log_manager, flags);
289 return RunLocalComputation(phase_logger, env_deps, log_manager,
290 opstats_logger.get(), flags, plan_uri,
291 input_dir_uri, output_dir_uri, input_resources,
292 selector_context);
293 }
294
RunLocalComputation(PhaseLogger & phase_logger,SimpleTaskEnvironment * env_deps,LogManager * log_manager,OpStatsLogger * opstats_logger,const Flags * flags,const std::string & plan_uri,const std::string & input_dir_uri,const std::string & output_dir_uri,const absl::flat_hash_map<std::string,std::string> & input_resources,const SelectorContext & selector_context)295 absl::Status RunLocalComputation(
296 PhaseLogger& phase_logger, SimpleTaskEnvironment* env_deps,
297 LogManager* log_manager, OpStatsLogger* opstats_logger, const Flags* flags,
298 const std::string& plan_uri, const std::string& input_dir_uri,
299 const std::string& output_dir_uri,
300 const absl::flat_hash_map<std::string, std::string>& input_resources,
301 const SelectorContext& selector_context) {
302 absl::Time reference_time = absl::Now();
303 absl::Duration polling_period =
304 absl::Milliseconds(flags->condition_polling_period_millis());
305 std::function<bool()> should_abort = [env_deps, polling_period]() {
306 return env_deps->ShouldAbort(absl::Now(), polling_period);
307 };
308 // Check if the device conditions allow running a local computation.
309 if (should_abort()) {
310 std::string message =
311 "Device conditions not satisfied, aborting local computation";
312 FCP_LOG(INFO) << message;
313 phase_logger.LogTaskNotStarted(message);
314 return absl::CancelledError("");
315 }
316 // Local compute plans can use example iterators from the
317 // SimpleTaskEnvironment and those reading the OpStats DB.
318 opstats::OpStatsExampleIteratorFactory opstats_example_iterator_factory(
319 opstats_logger, log_manager,
320 flags->opstats_last_successful_contribution_criteria());
321 std::unique_ptr<engine::ExampleIteratorFactory> env_example_iterator_factory =
322 CreateSimpleTaskEnvironmentIteratorFactory(env_deps, selector_context);
323 std::vector<engine::ExampleIteratorFactory*> example_iterator_factories{
324 &opstats_example_iterator_factory, env_example_iterator_factory.get()};
325
326 fcp::client::InterruptibleRunner::TimingConfig timing_config = {
327 .polling_period = polling_period,
328 .graceful_shutdown_period = absl::Milliseconds(
329 flags->tf_execution_teardown_grace_period_millis()),
330 .extended_shutdown_period = absl::Milliseconds(
331 flags->tf_execution_teardown_extended_period_millis()),
332 };
333
334 absl::Time run_plan_start_time = absl::Now();
335 phase_logger.LogComputationStarted();
336
337 absl::StatusOr<std::string> plan_str = fcp::ReadFileToString(plan_uri);
338 if (!plan_str.ok()) {
339 phase_logger.LogComputationIOError(plan_str.status(), ExampleStats(),
340 NetworkStats(), run_plan_start_time);
341 return plan_str.status();
342 }
343
344 ClientOnlyPlan plan;
345 if (!plan.ParseFromString(*plan_str)) {
346 absl::Status error_status =
347 absl::InvalidArgumentError("could not parse received plan");
348 phase_logger.LogComputationInvalidArgument(
349 error_status, ExampleStats(), NetworkStats(), run_plan_start_time);
350 return error_status;
351 }
352
353 std::vector<std::string> output_names;
354 std::vector<tensorflow::Tensor> output_tensors;
355 return RunPlanWithTensorflowSpec(
356 phase_logger, example_iterator_factories, should_abort, log_manager,
357 opstats_logger, flags, plan, input_dir_uri, output_dir_uri,
358 input_resources, timing_config, run_plan_start_time, reference_time);
359 }
360
361 } // namespace client
362 } // namespace fcp
363