xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/plan_engine_helpers.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
17 #define FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
18 
19 #include <atomic>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "fcp/base/monitoring.h"
29 #include "fcp/client/engine/common.h"
30 #include "fcp/client/engine/example_iterator_factory.h"
31 #include "fcp/client/event_publisher.h"
32 #include "fcp/client/flags.h"
33 #include "fcp/client/log_manager.h"
34 #include "fcp/client/opstats/opstats_logger.h"
35 #include "fcp/client/simple_task_environment.h"
36 #include "fcp/tensorflow/external_dataset.h"
37 #include "fcp/tensorflow/host_object.h"
38 #include "tensorflow/core/framework/tensor.h"
39 
40 // On Error Handling
41 // Calls in the engine are assumed to either
42 // 1. be successful (Status::OK)
43 // 2. fail with an "expected" error -> handle gracefully - log error, tell the
44 //    environment (via finish), return
45 // 3. encounter "unexpected" errors; when originating inside the engine or in
46 //    native code in the environment, or from java, crash.
47 // While this type of tristate error handling is easy in Java (success, checked,
48 // unchecked exceptions), it isn't in C++, hence we adopt the following
49 // convention for control flow/error handling inside the engine:
50 // - all functions in the plan engine downstream of runPhase() that can fail
51 //   must return a Status with one of the following codes: INTERNAL_ERROR,
52 //   CANCELLED, INVALID_ARGUMENT, OK. Only on OK will normal execution continue,
53 //   otherwise return up to the top level (runPhase). Once at the top level,
54 //   those error codes will be handled as follows:
55 //   a) CANCELLED -> report INTERRUPTED to env
56 //   b) INTERNAL_ERROR/INVALID_ARGUMENT -> report ERROR to env
57 //   c) OK -> report COMPLETED to env
58 //   For all status codes, the TaskRetry returned from the env is returned.
59 // - utility functions outside of the engine will also use Status/StatusOr, but
60 //   may use other error codes (e.g. the TensorFlowWrapper or ExampleIterator
61 //   use OUT_OF_RANGE).
// To propagate errors, use the following macro:
// - FCP_ENGINE_RETURN_IF_ERROR(...): returns the status from the enclosing
//   function if its code is not OK; otherwise execution continues.
65 
66 namespace fcp {
67 namespace client {
68 namespace engine {
69 namespace internal {
AsStatus(absl::Status status)70 inline absl::Status AsStatus(absl::Status status) { return status; }
71 }  // namespace internal
72 
// Evaluates `status_or_statusor_expr` exactly once. The argument may be any
// type accepted by internal::AsStatus. If the resulting absl::Status is not OK,
// the macro returns it from the enclosing function; otherwise execution
// continues.
//
// The temporary is deliberately not named with a leading or embedded double
// underscore, because such identifiers are reserved for the implementation in
// C++. It is also not const, so the early return can move it.
#define FCP_ENGINE_RETURN_IF_ERROR(status_or_statusor_expr)                 \
  do {                                                                      \
    absl::Status fcp_engine_return_if_error_status =                        \
        ::fcp::client::engine::internal::AsStatus(status_or_statusor_expr); \
    if (ABSL_PREDICT_FALSE(!fcp_engine_return_if_error_status.ok())) {      \
      return fcp_engine_return_if_error_status;                             \
    }                                                                       \
  } while (0)
83 
84 // Tracks whether any example iterator encountered an error during the
85 // computation (a single computation may use multiple iterators), either during
86 // creation of the iterator or during one of the iterations.
87 // This class is thread-safe.
88 class ExampleIteratorStatus {
89  public:
90   void SetStatus(absl::Status status) ABSL_LOCKS_EXCLUDED(mu_);
91   absl::Status GetStatus() ABSL_LOCKS_EXCLUDED(mu_);
92 
93  private:
94   absl::Status status_ ABSL_GUARDED_BY(mu_) = absl::OkStatus();
95   mutable absl::Mutex mu_;
96 };
97 
98 // A class to iterate over a given example iterator.
99 class DatasetIterator : public ExternalDatasetIterator {
100  public:
101   DatasetIterator(std::unique_ptr<ExampleIterator> example_iterator,
102                   opstats::OpStatsLogger* opstats_logger,
103                   std::atomic<int>* total_example_count,
104                   std::atomic<int64_t>* total_example_size_bytes,
105                   ExampleIteratorStatus* example_iterator_status,
106                   const std::string& collection_uri, bool collect_stats);
107   ~DatasetIterator() override;
108 
109   // Returns the next entry from the dataset.
110   absl::StatusOr<std::string> GetNext() final;
111 
112  private:
113   std::unique_ptr<ExampleIterator> example_iterator_
114       ABSL_GUARDED_BY(iterator_lock_);
115   opstats::OpStatsLogger* opstats_logger_;
116   absl::Time iterator_start_time_;
117   // Example stats across all datasets.
118   std::atomic<int>* total_example_count_;
119   std::atomic<int64_t>* total_example_size_bytes_;
120   ExampleIteratorStatus* example_iterator_status_;
121   // Example stats only for this dataset.
122   std::atomic<int> example_count_;
123   std::atomic<int64_t> example_size_bytes_;
124   const std::string collection_uri_;
125   bool iterator_finished_ ABSL_GUARDED_BY(iterator_lock_);
126   const bool collect_stats_;
127   absl::Mutex iterator_lock_;
128 };
129 
130 // Sets up a ExternalDatasetProvider that is registered with the global
131 // HostObjectRegistry. Adds a tensor representing the HostObjectRegistration
132 // token to the input tensors with the provided dataset_token_tensor_name key.
133 //
134 // For each example query issued by the plan at runtime, the given
135 // `example_iterator_factories` parameter will be iterated and the first
136 // iterator factory that can handle the given query will be used to create the
137 // example iterator to handle that query.
138 HostObjectRegistration AddDatasetTokenToInputs(
139     std::vector<ExampleIteratorFactory*> example_iterator_factories,
140     ::fcp::client::opstats::OpStatsLogger* opstats_logger,
141     std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
142     const std::string& dataset_token_tensor_name,
143     std::atomic<int>* total_example_count,
144     std::atomic<int64_t>* total_example_size_bytes,
145     ExampleIteratorStatus* example_iterator_status);
146 
147 // Sets up an ExternalDatasetProvider that is registered with the global
148 // HostObjectRegistry. Adds a string representing the HostObjectRegistration
149 // token to the map of input tensor name and values with the provided
150 // dataset_token_tensor_name key.
151 //
152 // For each example query issued by the plan at runtime, the given
153 // `example_iterator_factories` parameter will be iterated and the first
154 // iterator factory that can handle the given query will be used to create the
155 // example iterator to handle that query.
156 HostObjectRegistration AddDatasetTokenToInputsForTfLite(
157     std::vector<ExampleIteratorFactory*> example_iterator_factories,
158     ::fcp::client::opstats::OpStatsLogger* opstats_logger,
159     absl::flat_hash_map<std::string, std::string>* inputs,
160     const std::string& dataset_token_tensor_name,
161     std::atomic<int>* total_example_count,
162     std::atomic<int64_t>* total_example_size_bytes,
163     ExampleIteratorStatus* example_iterator_status);
164 
165 // If opstats is enabled, this method attempts to create an opstats logger
166 // backed by a database within base_dir and prepares to record information for a
167 // training run with the provided session and population names. If there is an
168 // error initializing the db or opstats is disabled, creates a no-op logger.
169 std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
170     const std::string& base_dir, const Flags* flags, LogManager* log_manager,
171     const std::string& session_name, const std::string& population_name);
172 
173 // Utility for creating a PlanResult when an `INVALID_ARGUMENT` TensorFlow error
174 // was encountered, disambiguating between generic TF errors and TF errors that
175 // were likely root-caused by an earlier example iterator error.
176 PlanResult CreateComputationErrorPlanResult(
177     absl::Status example_iterator_status,
178     absl::Status computation_error_status);
179 
180 // Finds a suitable example iterator factory out of provided factories based on
181 // the provided selector.
182 ExampleIteratorFactory* FindExampleIteratorFactory(
183     const google::internal::federated::plan::ExampleSelector& selector,
184     std::vector<ExampleIteratorFactory*> example_iterator_factories);
185 
186 }  // namespace engine
187 }  // namespace client
188 }  // namespace fcp
189 
190 #endif  // FCP_CLIENT_ENGINE_PLAN_ENGINE_HELPERS_H_
191