1 /*
2 * Copyright 2020 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/client/engine/plan_engine_helpers.h"
17
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>
22
23 #include "absl/status/statusor.h"
24 #include "fcp/client/diag_codes.pb.h"
25 #include "fcp/client/opstats/opstats_logger.h"
26 #include "fcp/client/opstats/opstats_logger_impl.h"
27 // #include "fcp/client/opstats/pds_backed_opstats_db.h"
28 #include "fcp/protos/plan.pb.h"
29 #include "fcp/tensorflow/external_dataset.h"
30
31 namespace fcp {
32 namespace client {
33 namespace engine {
34 namespace {
35
36 using ::fcp::client::opstats::OpStatsLogger;
37 using ::fcp::client::opstats::OpStatsLoggerImpl;
38 // using ::fcp::client::opstats::PdsBackedOpStatsDb;
39 using ::google::internal::federated::plan::ExampleSelector;
40
41 /** An iterator that forwards the failing status from the external dataset to
42 * TensorFlow. */
43 class FailingDatasetIterator : public ExternalDatasetIterator {
44 public:
FailingDatasetIterator(absl::Status status)45 explicit FailingDatasetIterator(absl::Status status) : status_(status) {}
46
GetNext()47 absl::StatusOr<std::string> GetNext() final { return status_; }
48
49 private:
50 const absl::Status status_;
51 };
52
53 class TrainingDatasetProvider
54 : public ExternalDatasetProvider::UsingProtoSelector<ExampleSelector> {
55 public:
TrainingDatasetProvider(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)56 TrainingDatasetProvider(
57 std::vector<ExampleIteratorFactory*> example_iterator_factories,
58 OpStatsLogger* opstats_logger, std::atomic<int>* total_example_count,
59 std::atomic<int64_t>* total_example_size_bytes,
60 ExampleIteratorStatus* example_iterator_status)
61 : example_iterator_factories_(example_iterator_factories),
62 opstats_logger_(opstats_logger),
63 total_example_count_(total_example_count),
64 total_example_size_bytes_(total_example_size_bytes),
65 example_iterator_status_(example_iterator_status) {}
66
MakeDataset(ExampleSelector selector)67 absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
68 ExampleSelector selector) final {
69 return ExternalDataset::FromFunction(
70 [example_iterator_factories = example_iterator_factories_,
71 opstats_logger = opstats_logger_, selector,
72 total_example_count = total_example_count_,
73 total_example_size_bytes = total_example_size_bytes_,
74 example_iterator_status = example_iterator_status_]()
75 -> std::unique_ptr<ExternalDatasetIterator> {
76 ExampleIteratorFactory* example_iterator_factory =
77 FindExampleIteratorFactory(selector, example_iterator_factories);
78 // The DatasetOp requires a valid iterator at this stage so return an
79 // empty iterator if there was an error.
80 if (example_iterator_factory == nullptr) {
81 absl::Status error(
82 absl::StatusCode::kInternal,
83 "Could not find suitable ExampleIteratorFactory");
84 example_iterator_status->SetStatus(error);
85 return std::make_unique<FailingDatasetIterator>(error);
86 }
87 absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
88 example_iterator_factory->CreateExampleIterator(selector);
89 if (!example_iterator.ok()) {
90 example_iterator_status->SetStatus(example_iterator.status());
91 return std::make_unique<FailingDatasetIterator>(
92 example_iterator.status());
93 }
94 return std::make_unique<DatasetIterator>(
95 std::move(*example_iterator), opstats_logger, total_example_count,
96 total_example_size_bytes, example_iterator_status,
97 selector.collection_uri(),
98 /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
99 });
100 }
101
102 private:
103 std::vector<ExampleIteratorFactory*> example_iterator_factories_;
104 OpStatsLogger* opstats_logger_;
105 std::atomic<int>* total_example_count_;
106 std::atomic<int64_t>* total_example_size_bytes_;
107 ExampleIteratorStatus* example_iterator_status_;
108 };
109
110 } // namespace
111
DatasetIterator(std::unique_ptr<ExampleIterator> example_iterator,opstats::OpStatsLogger * opstats_logger,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status,const std::string & collection_uri,bool collect_stats)112 DatasetIterator::DatasetIterator(
113 std::unique_ptr<ExampleIterator> example_iterator,
114 opstats::OpStatsLogger* opstats_logger,
115 std::atomic<int>* total_example_count,
116 std::atomic<int64_t>* total_example_size_bytes,
117 ExampleIteratorStatus* example_iterator_status,
118 const std::string& collection_uri, bool collect_stats)
119 : example_iterator_(std::move(example_iterator)),
120 opstats_logger_(opstats_logger),
121 iterator_start_time_(absl::Now()),
122 total_example_count_(total_example_count),
123 total_example_size_bytes_(total_example_size_bytes),
124 example_iterator_status_(example_iterator_status),
125 example_count_(0),
126 example_size_bytes_(0),
127 collection_uri_(collection_uri),
128 iterator_finished_(false),
129 collect_stats_(collect_stats) {}
130
~DatasetIterator()131 DatasetIterator::~DatasetIterator() {
132 if (collect_stats_) {
133 opstats_logger_->UpdateDatasetStats(collection_uri_, example_count_,
134 example_size_bytes_);
135 }
136 }
137
138 // Returns the next entry from the dataset.
GetNext()139 absl::StatusOr<std::string> DatasetIterator::GetNext() {
140 absl::MutexLock locked(&iterator_lock_);
141 if (iterator_finished_) {
142 // If we've reached the end of the iterator, always return OUT_OF_RANGE.
143 return absl::OutOfRangeError("End of iterator reached");
144 }
145 absl::StatusOr<std::string> example = example_iterator_->Next();
146 absl::StatusCode error_code = example.status().code();
147 example_iterator_status_->SetStatus(example.status());
148 if (error_code == absl::StatusCode::kOutOfRange) {
149 example_iterator_->Close();
150 iterator_finished_ = true;
151 }
152 // If we're not forwarding an OUT_OF_RANGE to the caller, record example
153 // stats for metrics logging.
154 if (collect_stats_ && example.ok()) {
155 // TODO(team): Consider reducing logic duplication in
156 // cross-dataset and single-dataset example stat variables.
157 *total_example_count_ += 1;
158 *total_example_size_bytes_ += example->size();
159 example_count_ += 1;
160 example_size_bytes_ += example->size();
161 }
162 return example;
163 }
164
SetStatus(absl::Status status)165 void ExampleIteratorStatus::SetStatus(absl::Status status) {
166 absl::MutexLock lock(&mu_);
167 // We ignores normal status such as ok and outOfRange to avoid running into a
168 // race condition when an error happened, then an outofRange or ok status
169 // returned in a different thread which overrides the error status.
170 if (status.code() != absl::StatusCode::kOk &&
171 status.code() != absl::StatusCode::kOutOfRange) {
172 status_ = status;
173 }
174 }
175
GetStatus()176 absl::Status ExampleIteratorStatus::GetStatus() {
177 absl::MutexLock lock(&mu_);
178 return status_;
179 }
180
AddDatasetTokenToInputs(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,std::vector<std::pair<std::string,tensorflow::Tensor>> * inputs,const std::string & dataset_token_tensor_name,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)181 HostObjectRegistration AddDatasetTokenToInputs(
182 std::vector<ExampleIteratorFactory*> example_iterator_factories,
183 OpStatsLogger* opstats_logger,
184 std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
185 const std::string& dataset_token_tensor_name,
186 std::atomic<int>* total_example_count,
187 std::atomic<int64_t>* total_example_size_bytes,
188 ExampleIteratorStatus* example_iterator_status) {
189 // Register the TrainingDatasetProvider with the global
190 // ExternalDatasetProviderRegistry.
191 auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
192 std::make_shared<TrainingDatasetProvider>(
193 example_iterator_factories, opstats_logger, total_example_count,
194 total_example_size_bytes, example_iterator_status));
195 // Pack the token returned from registering the provider into a string
196 // tensor. TensorFlow will use that token via the ExternalDatasetOp to create
197 // datasets and iterators.
198 tensorflow::Tensor token_scalar(std::string{});
199 token_scalar.scalar<tensorflow::tstring>()() =
200 host_registration.token().ToString();
201 std::pair<std::string, tensorflow::Tensor> token_pair(
202 dataset_token_tensor_name, token_scalar);
203 inputs->emplace_back(token_pair);
204 return host_registration;
205 }
206
AddDatasetTokenToInputsForTfLite(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,absl::flat_hash_map<std::string,std::string> * inputs,const std::string & dataset_token_tensor_name,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)207 HostObjectRegistration AddDatasetTokenToInputsForTfLite(
208 std::vector<ExampleIteratorFactory*> example_iterator_factories,
209 OpStatsLogger* opstats_logger,
210 absl::flat_hash_map<std::string, std::string>* inputs,
211 const std::string& dataset_token_tensor_name,
212 std::atomic<int>* total_example_count,
213 std::atomic<int64_t>* total_example_size_bytes,
214 ExampleIteratorStatus* example_iterator_status) {
215 // Registers the TrainingDatasetProvider with the global
216 // ExternalDatasetProviderRegistry.
217 auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
218 std::make_shared<TrainingDatasetProvider>(
219 example_iterator_factories, opstats_logger, total_example_count,
220 total_example_size_bytes, example_iterator_status));
221 // Adds the token returned from registering the provider to the map of inputs.
222 // TfLite will use that token via the ExternalDatasetOp to create
223 // datasets and iterators.
224 (*inputs)[dataset_token_tensor_name] = host_registration.token().ToString();
225 return host_registration;
226 }
227
CreateOpStatsLogger(const std::string & base_dir,const Flags * flags,LogManager * log_manager,const std::string & session_name,const std::string & population_name)228 std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
229 const std::string& base_dir, const Flags* flags, LogManager* log_manager,
230 const std::string& session_name, const std::string& population_name) {
231 // if (flags->enable_opstats()) {
232 // auto db_or = PdsBackedOpStatsDb::Create(
233 // base_dir, flags->opstats_ttl_days() * absl::Hours(24), *log_manager,
234 // flags->opstats_db_size_limit_bytes());
235 // if (db_or.ok()) {
236 // return std::make_unique<OpStatsLoggerImpl>(
237 // std::move(db_or).value(), log_manager, flags, session_name,
238 // population_name);
239 // } else {
240 // if (flags->log_opstats_initialization_errors()) {
241 // return std::make_unique<OpStatsLogger>(
242 // /*opstats_enabled=*/flags->enable_opstats(),
243 // /*init_status=*/db_or.status());
244 // }
245 // }
246 // }
247 return std::make_unique<OpStatsLogger>(
248 /*opstats_enabled=*/flags->enable_opstats());
249 }
250
CreateComputationErrorPlanResult(absl::Status example_iterator_status,absl::Status computation_error_status)251 PlanResult CreateComputationErrorPlanResult(
252 absl::Status example_iterator_status,
253 absl::Status computation_error_status) {
254 switch (example_iterator_status.code()) {
255 case absl::StatusCode::kOk:
256 case absl::StatusCode::kOutOfRange:
257 // Either example iterators are working fine or we don't know the status
258 // of the example iterators. In this case, we'll use the error status
259 // returned from TensorFlow.
260 return PlanResult(PlanOutcome::kTensorflowError,
261 computation_error_status);
262 case absl::StatusCode::kCancelled:
263 // Example iterator got interrupted.
264 return PlanResult(PlanOutcome::kInterrupted, example_iterator_status);
265 default:
266 // All other Example iterator errors.
267 return PlanResult(PlanOutcome::kExampleIteratorError,
268 example_iterator_status);
269 }
270 }
271
FindExampleIteratorFactory(const ExampleSelector & selector,std::vector<ExampleIteratorFactory * > example_iterator_factories)272 ExampleIteratorFactory* FindExampleIteratorFactory(
273 const ExampleSelector& selector,
274 std::vector<ExampleIteratorFactory*> example_iterator_factories) {
275 for (ExampleIteratorFactory* factory : example_iterator_factories) {
276 if (factory->CanHandle(selector)) {
277 return factory;
278 }
279 }
280 return nullptr;
281 }
282
283 } // namespace engine
284 } // namespace client
285 } // namespace fcp
286