xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/plan_engine_helpers.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2020 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/engine/plan_engine_helpers.h"
17 
18 #include <functional>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/status/statusor.h"
24 #include "fcp/client/diag_codes.pb.h"
25 #include "fcp/client/opstats/opstats_logger.h"
26 #include "fcp/client/opstats/opstats_logger_impl.h"
27 // #include "fcp/client/opstats/pds_backed_opstats_db.h"
28 #include "fcp/protos/plan.pb.h"
29 #include "fcp/tensorflow/external_dataset.h"
30 
31 namespace fcp {
32 namespace client {
33 namespace engine {
34 namespace {
35 
36 using ::fcp::client::opstats::OpStatsLogger;
37 using ::fcp::client::opstats::OpStatsLoggerImpl;
38 // using ::fcp::client::opstats::PdsBackedOpStatsDb;
39 using ::google::internal::federated::plan::ExampleSelector;
40 
41 /** An iterator that forwards the failing status from the external dataset to
42  * TensorFlow. */
43 class FailingDatasetIterator : public ExternalDatasetIterator {
44  public:
FailingDatasetIterator(absl::Status status)45   explicit FailingDatasetIterator(absl::Status status) : status_(status) {}
46 
GetNext()47   absl::StatusOr<std::string> GetNext() final { return status_; }
48 
49  private:
50   const absl::Status status_;
51 };
52 
53 class TrainingDatasetProvider
54     : public ExternalDatasetProvider::UsingProtoSelector<ExampleSelector> {
55  public:
TrainingDatasetProvider(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)56   TrainingDatasetProvider(
57       std::vector<ExampleIteratorFactory*> example_iterator_factories,
58       OpStatsLogger* opstats_logger, std::atomic<int>* total_example_count,
59       std::atomic<int64_t>* total_example_size_bytes,
60       ExampleIteratorStatus* example_iterator_status)
61       : example_iterator_factories_(example_iterator_factories),
62         opstats_logger_(opstats_logger),
63         total_example_count_(total_example_count),
64         total_example_size_bytes_(total_example_size_bytes),
65         example_iterator_status_(example_iterator_status) {}
66 
MakeDataset(ExampleSelector selector)67   absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
68       ExampleSelector selector) final {
69     return ExternalDataset::FromFunction(
70         [example_iterator_factories = example_iterator_factories_,
71          opstats_logger = opstats_logger_, selector,
72          total_example_count = total_example_count_,
73          total_example_size_bytes = total_example_size_bytes_,
74          example_iterator_status = example_iterator_status_]()
75             -> std::unique_ptr<ExternalDatasetIterator> {
76           ExampleIteratorFactory* example_iterator_factory =
77               FindExampleIteratorFactory(selector, example_iterator_factories);
78           // The DatasetOp requires a valid iterator at this stage so return an
79           // empty iterator if there was an error.
80           if (example_iterator_factory == nullptr) {
81             absl::Status error(
82                 absl::StatusCode::kInternal,
83                 "Could not find suitable ExampleIteratorFactory");
84             example_iterator_status->SetStatus(error);
85             return std::make_unique<FailingDatasetIterator>(error);
86           }
87           absl::StatusOr<std::unique_ptr<ExampleIterator>> example_iterator =
88               example_iterator_factory->CreateExampleIterator(selector);
89           if (!example_iterator.ok()) {
90             example_iterator_status->SetStatus(example_iterator.status());
91             return std::make_unique<FailingDatasetIterator>(
92                 example_iterator.status());
93           }
94           return std::make_unique<DatasetIterator>(
95               std::move(*example_iterator), opstats_logger, total_example_count,
96               total_example_size_bytes, example_iterator_status,
97               selector.collection_uri(),
98               /*collect_stats=*/example_iterator_factory->ShouldCollectStats());
99         });
100   }
101 
102  private:
103   std::vector<ExampleIteratorFactory*> example_iterator_factories_;
104   OpStatsLogger* opstats_logger_;
105   std::atomic<int>* total_example_count_;
106   std::atomic<int64_t>* total_example_size_bytes_;
107   ExampleIteratorStatus* example_iterator_status_;
108 };
109 
110 }  // namespace
111 
DatasetIterator(std::unique_ptr<ExampleIterator> example_iterator,opstats::OpStatsLogger * opstats_logger,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status,const std::string & collection_uri,bool collect_stats)112 DatasetIterator::DatasetIterator(
113     std::unique_ptr<ExampleIterator> example_iterator,
114     opstats::OpStatsLogger* opstats_logger,
115     std::atomic<int>* total_example_count,
116     std::atomic<int64_t>* total_example_size_bytes,
117     ExampleIteratorStatus* example_iterator_status,
118     const std::string& collection_uri, bool collect_stats)
119     : example_iterator_(std::move(example_iterator)),
120       opstats_logger_(opstats_logger),
121       iterator_start_time_(absl::Now()),
122       total_example_count_(total_example_count),
123       total_example_size_bytes_(total_example_size_bytes),
124       example_iterator_status_(example_iterator_status),
125       example_count_(0),
126       example_size_bytes_(0),
127       collection_uri_(collection_uri),
128       iterator_finished_(false),
129       collect_stats_(collect_stats) {}
130 
~DatasetIterator()131 DatasetIterator::~DatasetIterator() {
132   if (collect_stats_) {
133     opstats_logger_->UpdateDatasetStats(collection_uri_, example_count_,
134                                         example_size_bytes_);
135   }
136 }
137 
138 // Returns the next entry from the dataset.
GetNext()139 absl::StatusOr<std::string> DatasetIterator::GetNext() {
140   absl::MutexLock locked(&iterator_lock_);
141   if (iterator_finished_) {
142     // If we've reached the end of the iterator, always return OUT_OF_RANGE.
143     return absl::OutOfRangeError("End of iterator reached");
144   }
145   absl::StatusOr<std::string> example = example_iterator_->Next();
146   absl::StatusCode error_code = example.status().code();
147   example_iterator_status_->SetStatus(example.status());
148   if (error_code == absl::StatusCode::kOutOfRange) {
149     example_iterator_->Close();
150     iterator_finished_ = true;
151   }
152   // If we're not forwarding an OUT_OF_RANGE to the caller, record example
153   // stats for metrics logging.
154   if (collect_stats_ && example.ok()) {
155     // TODO(team): Consider reducing logic duplication in
156     // cross-dataset and single-dataset example stat variables.
157     *total_example_count_ += 1;
158     *total_example_size_bytes_ += example->size();
159     example_count_ += 1;
160     example_size_bytes_ += example->size();
161   }
162   return example;
163 }
164 
SetStatus(absl::Status status)165 void ExampleIteratorStatus::SetStatus(absl::Status status) {
166   absl::MutexLock lock(&mu_);
167   // We ignores normal status such as ok and outOfRange to avoid running into a
168   // race condition when an error happened, then an outofRange or ok status
169   // returned in a different thread which overrides the error status.
170   if (status.code() != absl::StatusCode::kOk &&
171       status.code() != absl::StatusCode::kOutOfRange) {
172     status_ = status;
173   }
174 }
175 
GetStatus()176 absl::Status ExampleIteratorStatus::GetStatus() {
177   absl::MutexLock lock(&mu_);
178   return status_;
179 }
180 
AddDatasetTokenToInputs(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,std::vector<std::pair<std::string,tensorflow::Tensor>> * inputs,const std::string & dataset_token_tensor_name,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)181 HostObjectRegistration AddDatasetTokenToInputs(
182     std::vector<ExampleIteratorFactory*> example_iterator_factories,
183     OpStatsLogger* opstats_logger,
184     std::vector<std::pair<std::string, tensorflow::Tensor>>* inputs,
185     const std::string& dataset_token_tensor_name,
186     std::atomic<int>* total_example_count,
187     std::atomic<int64_t>* total_example_size_bytes,
188     ExampleIteratorStatus* example_iterator_status) {
189   // Register the TrainingDatasetProvider with the global
190   // ExternalDatasetProviderRegistry.
191   auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
192       std::make_shared<TrainingDatasetProvider>(
193           example_iterator_factories, opstats_logger, total_example_count,
194           total_example_size_bytes, example_iterator_status));
195   // Pack the token returned from registering the provider into a string
196   // tensor. TensorFlow will use that token via the ExternalDatasetOp to create
197   // datasets and iterators.
198   tensorflow::Tensor token_scalar(std::string{});
199   token_scalar.scalar<tensorflow::tstring>()() =
200       host_registration.token().ToString();
201   std::pair<std::string, tensorflow::Tensor> token_pair(
202       dataset_token_tensor_name, token_scalar);
203   inputs->emplace_back(token_pair);
204   return host_registration;
205 }
206 
AddDatasetTokenToInputsForTfLite(std::vector<ExampleIteratorFactory * > example_iterator_factories,OpStatsLogger * opstats_logger,absl::flat_hash_map<std::string,std::string> * inputs,const std::string & dataset_token_tensor_name,std::atomic<int> * total_example_count,std::atomic<int64_t> * total_example_size_bytes,ExampleIteratorStatus * example_iterator_status)207 HostObjectRegistration AddDatasetTokenToInputsForTfLite(
208     std::vector<ExampleIteratorFactory*> example_iterator_factories,
209     OpStatsLogger* opstats_logger,
210     absl::flat_hash_map<std::string, std::string>* inputs,
211     const std::string& dataset_token_tensor_name,
212     std::atomic<int>* total_example_count,
213     std::atomic<int64_t>* total_example_size_bytes,
214     ExampleIteratorStatus* example_iterator_status) {
215   // Registers the TrainingDatasetProvider with the global
216   // ExternalDatasetProviderRegistry.
217   auto host_registration = fcp::ExternalDatasetProviderRegistry::Register(
218       std::make_shared<TrainingDatasetProvider>(
219           example_iterator_factories, opstats_logger, total_example_count,
220           total_example_size_bytes, example_iterator_status));
221   // Adds the token returned from registering the provider to the map of inputs.
222   // TfLite will use that token via the ExternalDatasetOp to create
223   // datasets and iterators.
224   (*inputs)[dataset_token_tensor_name] = host_registration.token().ToString();
225   return host_registration;
226 }
227 
CreateOpStatsLogger(const std::string & base_dir,const Flags * flags,LogManager * log_manager,const std::string & session_name,const std::string & population_name)228 std::unique_ptr<::fcp::client::opstats::OpStatsLogger> CreateOpStatsLogger(
229     const std::string& base_dir, const Flags* flags, LogManager* log_manager,
230     const std::string& session_name, const std::string& population_name) {
231   // if (flags->enable_opstats()) {
232   //   auto db_or = PdsBackedOpStatsDb::Create(
233   //       base_dir, flags->opstats_ttl_days() * absl::Hours(24), *log_manager,
234   //       flags->opstats_db_size_limit_bytes());
235   //   if (db_or.ok()) {
236   //       return std::make_unique<OpStatsLoggerImpl>(
237   //           std::move(db_or).value(), log_manager, flags, session_name,
238   //           population_name);
239   //   } else {
240   //       if (flags->log_opstats_initialization_errors()) {
241   //         return std::make_unique<OpStatsLogger>(
242   //             /*opstats_enabled=*/flags->enable_opstats(),
243   //             /*init_status=*/db_or.status());
244   //       }
245   //   }
246   // }
247   return std::make_unique<OpStatsLogger>(
248       /*opstats_enabled=*/flags->enable_opstats());
249 }
250 
CreateComputationErrorPlanResult(absl::Status example_iterator_status,absl::Status computation_error_status)251 PlanResult CreateComputationErrorPlanResult(
252     absl::Status example_iterator_status,
253     absl::Status computation_error_status) {
254   switch (example_iterator_status.code()) {
255     case absl::StatusCode::kOk:
256     case absl::StatusCode::kOutOfRange:
257       // Either example iterators are working fine or we don't know the status
258       // of the example iterators. In this case, we'll use the error status
259       // returned from TensorFlow.
260       return PlanResult(PlanOutcome::kTensorflowError,
261                         computation_error_status);
262     case absl::StatusCode::kCancelled:
263       // Example iterator got interrupted.
264       return PlanResult(PlanOutcome::kInterrupted, example_iterator_status);
265     default:
266       // All other Example iterator errors.
267       return PlanResult(PlanOutcome::kExampleIteratorError,
268                         example_iterator_status);
269   }
270 }
271 
FindExampleIteratorFactory(const ExampleSelector & selector,std::vector<ExampleIteratorFactory * > example_iterator_factories)272 ExampleIteratorFactory* FindExampleIteratorFactory(
273     const ExampleSelector& selector,
274     std::vector<ExampleIteratorFactory*> example_iterator_factories) {
275   for (ExampleIteratorFactory* factory : example_iterator_factories) {
276     if (factory->CanHandle(selector)) {
277       return factory;
278     }
279   }
280   return nullptr;
281 }
282 
283 }  // namespace engine
284 }  // namespace client
285 }  // namespace fcp
286