1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
17 #define FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
18
19 #include <atomic>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "fcp/base/monitoring.h"
29 #include "fcp/client/engine/common.h"
30 #include "fcp/client/engine/example_iterator_factory.h"
31 #include "fcp/client/event_publisher.h"
32 #include "fcp/client/flags.h"
33 #include "fcp/client/log_manager.h"
34 #include "fcp/client/opstats/opstats_logger.h"
35 #include "fcp/client/simple_task_environment.h"
36 #include "fcp/tensorflow/external_dataset.h"
37 #include "fcp/tensorflow/host_object.h"
38 #include "tensorflow/core/framework/tensor.h"
39
40 // On Error Handling
41 // Calls in the engine are assumed to either
42 // 1. be successful (Status::OK)
43 // 2. fail with an "expected" error -> handle gracefully - log error, tell the
44 // environment (via finish), return
45 // 3. encounter "unexpected" errors; when originating inside the engine or in
46 // native code in the environment, or from java, crash.
47 // While this type of tristate error handling is easy in Java (success, checked,
48 // unchecked exceptions), it isn't in C++, hence we adopt the following
49 // convention for control flow/error handling inside the engine:
50 // - all functions in the plan engine downstream of runPhase() that can fail
51 // must return a Status with one of the following codes: INTERNAL_ERROR,
52 // CANCELLED, INVALID_ARGUMENT, OK. Only on OK will normal execution continue,
53 // otherwise return up to the top level (runPhase). Once at the top level,
54 // those error codes will be handled as follows:
55 // a) CANCELLED -> report INTERRUPTED to env
56 // b) INTERNAL_ERROR/INVALID_ARGUMENT -> report ERROR to env
57 // c) OK -> report COMPLETED to env
58 // For all status codes, the TaskRetry returned from the env is returned.
59 // - utility functions outside of the engine will also use Status/StatusOr, but
60 // may use other error codes (e.g. the TensorFlowWrapper or ExampleIterator
61 // use OUT_OF_RANGE).
// To keep this return-on-error convention concise, the engine uses one macro:
//    FCP_ENGINE_RETURN_IF_ERROR(...): Return if the Status code is not OK,
//    else continue.
65
66 namespace fcp {
67 namespace client {
68 namespace engine {
69 namespace internal {
AsStatus(absl::Status status)70 inline absl::Status AsStatus(absl::Status status) { return status; }
71 } // namespace internal
72
// Evaluates `status_or_statusor_expr` exactly once, converts the result to an
// absl::Status via ::fcp::client::engine::internal::AsStatus, and returns that
// Status from the enclosing function if it is not OK. On OK, execution
// continues with the statement following the macro.
//
// The temporary deliberately avoids a double-underscore name (such names are
// reserved for the implementation) and is non-const so that the `return` can
// move it instead of copying.
#define FCP_ENGINE_RETURN_IF_ERROR(status_or_statusor_expr) \
  do { \
    absl::Status fcp_engine_return_if_error_status = \
        ::fcp::client::engine::internal::AsStatus(status_or_statusor_expr); \
    if (ABSL_PREDICT_FALSE(!fcp_engine_return_if_error_status.ok())) { \
      return fcp_engine_return_if_error_status; \
    } \
  } while (0)
83
84 // Tracks whether any example iterator encountered an error during the
85 // computation (a single computation may use multiple iterators), either during
86 // creation of the iterator or during one of the iterations.
87 // This class is thread-safe.
88 class ExampleIteratorStatus {
89 public:
90 void SetStatus(absl::Status status) ABSL_LOCKS_EXCLUDED(mu_);
91 absl::Status GetStatus() ABSL_LOCKS_EXCLUDED(mu_);
92
93 private:
94 absl::Status status_ ABSL_GUARDED_BY(mu_) = absl::OkStatus();
95 mutable absl::Mutex mu_;
96 };
97
98 // A class to iterate over a given example iterator.
99 class DatasetIterator : public ExternalDatasetIterator {
100 public:
101 DatasetIterator(std::unique_ptr<ExampleIterator> example_iterator,
102 opstats::OpStatsLogger* opstats_logger,
103 std::atomic<int>* total_example_count,
104 std::atomic<int64_t>* total_example_size_bytes,
105 ExampleIteratorStatus* example_iterator_status,
106 const std::string& collection_uri, bool collect_stats);
107 ~DatasetIterator() override;
108
109 // Returns the next entry from the dataset.
110 absl::StatusOr<std::string> GetNext() final;
111
112 private:
113 std::unique_ptr<ExampleIterator> example_iterator_
114 ABSL_GUARDED_BY(iterator_lock_);
115 opstats::OpStatsLogger* opstats_logger_;
116 absl::Time iterator_start_time_;
117 // Example stats across all datasets.
118 std::atomic<int>* total_example_count_;
119 std::atomic<int64_t>* total_example_size_bytes_;
120 ExampleIteratorStatus* example_iterator_status_;
121 // Example stats only for this dataset.
122 std::atomic<int> example_count_;
123 std::atomic<int64_t> example_size_bytes_;
124 const std::string collection_uri_;
125 bool iterator_finished_ ABSL_GUARDED_BY(iterator_lock_);
126 const bool collect_stats_;
127 absl::Mutex iterator_lock_;
128 };
129
130 // Sets up a ExternalDatasetProvider that is registered with the global
131 // HostObjectRegistry. Adds a tensor representing the HostObjectRegistration
132 // token to the input tensors with the provided dataset_token_tensor_name key.
133 //
134 // For each example query issued by the plan at runtime, the given
135 // `example_iterator_factories` parameter will be iterated and the first
136 // iterator factory that can handle the given query will be used to create the
137 // example iterator to handle that query.
138 HostObjectRegistration AddDatasetTokenToInputs(
139 std::vector<ExampleIteratorFactory*> example_iterator_factories,
140 ::fcp::client::opstats::OpStatsLogger* opstats_logger,
141 std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
142 const std::string& dataset_token_tensor_name,
143 std::atomic<int>* total_example_count,
144 std::atomic<int64_t>* total_example_size_bytes,
145 ExampleIteratorStatus* example_iterator_status);
146
147 // Sets up an ExternalDatasetProvider that is registered with the global
148 // HostObjectRegistry. Adds a string representing the HostObjectRegistration
149 // token to the map of input tensor name and values with the provided
150 // dataset_token_tensor_name key.
151 //
152 // For each example query issued by the plan at runtime, the given
153 // `example_iterator_factories` parameter will be iterated and the first
154 // iterator factory that can handle the given query will be used to create the
155 // example iterator to handle that query.
156 HostObjectRegistration AddDatasetTokenToInputsForTfLite(
157 std::vector<ExampleIteratorFactory*> example_iterator_factories,
158 ::fcp::client::opstats::OpStatsLogger* opstats_logger,
159 absl::flat_hash_map<std::string, std::string>* inputs,
160 const std::string& dataset_token_tensor_name,
161 std::atomic<int>* total_example_count,
162 std::atomic<int64_t>* total_example_size_bytes,
163 ExampleIteratorStatus* example_iterator_status);
164
165 // If opstats is enabled, this method attempts to create an opstats logger
166 // backed by a database within base_dir and prepares to record information for a
167 // training run with the provided session and population names. If there is an
168 // error initializing the db or opstats is disabled, creates a no-op logger.
169 std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
170 const std::string& base_dir, const Flags* flags, LogManager* log_manager,
171 const std::string& session_name, const std::string& population_name);
172
173 // Utility for creating a PlanResult when an `INVALID_ARGUMENT` TensorFlow error
174 // was encountered, disambiguating between generic TF errors and TF errors that
175 // were likely root-caused by an earlier example iterator error.
176 PlanResult CreateComputationErrorPlanResult(
177 absl::Status example_iterator_status,
178 absl::Status computation_error_status);
179
180 // Finds a suitable example iterator factory out of provided factories based on
181 // the provided selector.
182 ExampleIteratorFactory* FindExampleIteratorFactory(
183 const google::internal::federated::plan::ExampleSelector& selector,
184 std::vector<ExampleIteratorFactory*> example_iterator_factories);
185
186 } // namespace engine
187 } // namespace client
188 } // namespace fcp
189
190 #endif // FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
191