xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/external_dataset.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2019 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_
18*14675a02SAndroid Build Coastguard Worker #define FCP_TENSORFLOW_EXTERNAL_DATASET_H_
19*14675a02SAndroid Build Coastguard Worker 
#include <limits>
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "fcp/base/bounds.h"
#include "fcp/tensorflow/host_object.h"
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker namespace fcp {
31*14675a02SAndroid Build Coastguard Worker 
32*14675a02SAndroid Build Coastguard Worker /**
33*14675a02SAndroid Build Coastguard Worker  * Interface for an iterator, created from a particular dataset. A single
34*14675a02SAndroid Build Coastguard Worker  * dataset may be used to create multiple iterators.
35*14675a02SAndroid Build Coastguard Worker  */
36*14675a02SAndroid Build Coastguard Worker class ExternalDatasetIterator {
37*14675a02SAndroid Build Coastguard Worker  public:
38*14675a02SAndroid Build Coastguard Worker   virtual ~ExternalDatasetIterator() = default;
39*14675a02SAndroid Build Coastguard Worker 
40*14675a02SAndroid Build Coastguard Worker   /**
41*14675a02SAndroid Build Coastguard Worker    * Returns the next element, if possible. Indicates end-of-stream with
42*14675a02SAndroid Build Coastguard Worker    * OUT_OF_RANGE, even when repeatedly called. Corresponds to
43*14675a02SAndroid Build Coastguard Worker    * tensorflow::data::IteratorBase::GetNext.
44*14675a02SAndroid Build Coastguard Worker    *
45*14675a02SAndroid Build Coastguard Worker    * Implementations must be thread-safe.
46*14675a02SAndroid Build Coastguard Worker    */
47*14675a02SAndroid Build Coastguard Worker   virtual absl::StatusOr<std::string> GetNext() = 0;
48*14675a02SAndroid Build Coastguard Worker };
49*14675a02SAndroid Build Coastguard Worker 
namespace external_dataset_internal {

// Declared ahead of its definition (found later in this header) so that
// ExternalDataset::FromFunction can refer to it by name.
template <typename FuncType>
class DatasetFromFunction;

}  // namespace external_dataset_internal
56*14675a02SAndroid Build Coastguard Worker 
57*14675a02SAndroid Build Coastguard Worker /**
58*14675a02SAndroid Build Coastguard Worker  * Interface for a particular dataset - created from an ExternalDatasetProvider
59*14675a02SAndroid Build Coastguard Worker  * (during dataset op execution), for a particular selector. A dataset may be
60*14675a02SAndroid Build Coastguard Worker  * used zero or more times to create an ExternalDatasetIterator.
61*14675a02SAndroid Build Coastguard Worker  *
62*14675a02SAndroid Build Coastguard Worker  * Dataset implementations are often trivial, just needing to capture some
63*14675a02SAndroid Build Coastguard Worker  * values (like the selector) for the iterator constructor. Consider using
64*14675a02SAndroid Build Coastguard Worker  * ExternalDataset::FromFunction.
65*14675a02SAndroid Build Coastguard Worker  */
66*14675a02SAndroid Build Coastguard Worker class ExternalDataset {
67*14675a02SAndroid Build Coastguard Worker  public:
68*14675a02SAndroid Build Coastguard Worker   virtual ~ExternalDataset() = default;
69*14675a02SAndroid Build Coastguard Worker 
70*14675a02SAndroid Build Coastguard Worker   /**
71*14675a02SAndroid Build Coastguard Worker    * Creates a new iterator. Corresponds to
72*14675a02SAndroid Build Coastguard Worker    * tensorflow::data::DatasetBase::MakeIterator.
73*14675a02SAndroid Build Coastguard Worker    */
74*14675a02SAndroid Build Coastguard Worker   virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0;
75*14675a02SAndroid Build Coastguard Worker 
76*14675a02SAndroid Build Coastguard Worker   /**
77*14675a02SAndroid Build Coastguard Worker    * Creates an ExternalDataset that wraps a callable object 'f', implementing
78*14675a02SAndroid Build Coastguard Worker    * MakeIterator(). The lifetime of 'f' is that of the dataset (so,
79*14675a02SAndroid Build Coastguard Worker    * by-reference lambda captures are almost always unsafe here).
80*14675a02SAndroid Build Coastguard Worker    */
81*14675a02SAndroid Build Coastguard Worker   template <typename F>
FromFunction(F f)82*14675a02SAndroid Build Coastguard Worker   static std::unique_ptr<ExternalDataset> FromFunction(F f) {
83*14675a02SAndroid Build Coastguard Worker     return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>(
84*14675a02SAndroid Build Coastguard Worker         std::move(f));
85*14675a02SAndroid Build Coastguard Worker   }
86*14675a02SAndroid Build Coastguard Worker };
87*14675a02SAndroid Build Coastguard Worker 
88*14675a02SAndroid Build Coastguard Worker /**
89*14675a02SAndroid Build Coastguard Worker  * Interface for an ExternalDataset op's host object.
90*14675a02SAndroid Build Coastguard Worker  *
91*14675a02SAndroid Build Coastguard Worker  * An ExternalDatasetProvider is a function from Selector -> ExternalDataset.
92*14675a02SAndroid Build Coastguard Worker  * Here, 'Selector' is a string provided to the dataset op (typically, an
93*14675a02SAndroid Build Coastguard Worker  * encoded proto). The returned ExternalDataset may be used (perhaps multiple
94*14675a02SAndroid Build Coastguard Worker  * times) to create an iterator.
95*14675a02SAndroid Build Coastguard Worker  *
96*14675a02SAndroid Build Coastguard Worker  * When implementing a dataset provider and the selector is a proto message,
97*14675a02SAndroid Build Coastguard Worker  * consider inheritng from ExternalDatasetProvider::UsingProtoSelector<T> (for
98*14675a02SAndroid Build Coastguard Worker  * some message type T).
99*14675a02SAndroid Build Coastguard Worker  */
100*14675a02SAndroid Build Coastguard Worker class ExternalDatasetProvider {
101*14675a02SAndroid Build Coastguard Worker  public:
102*14675a02SAndroid Build Coastguard Worker   virtual ~ExternalDatasetProvider() = default;
103*14675a02SAndroid Build Coastguard Worker 
104*14675a02SAndroid Build Coastguard Worker   /**
105*14675a02SAndroid Build Coastguard Worker    * Creates a dataset for a given selector.
106*14675a02SAndroid Build Coastguard Worker    *
107*14675a02SAndroid Build Coastguard Worker    * This function can usually be implemented succinctly, using
108*14675a02SAndroid Build Coastguard Worker    * ExternalDataset::FromFunction.
109*14675a02SAndroid Build Coastguard Worker    *
110*14675a02SAndroid Build Coastguard Worker    * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset.
111*14675a02SAndroid Build Coastguard Worker    */
112*14675a02SAndroid Build Coastguard Worker   virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
113*14675a02SAndroid Build Coastguard Worker       absl::string_view selector) = 0;
114*14675a02SAndroid Build Coastguard Worker 
115*14675a02SAndroid Build Coastguard Worker   /**
116*14675a02SAndroid Build Coastguard Worker    * Base class for dataset providers that expect a selector of a particular
117*14675a02SAndroid Build Coastguard Worker    * proto message type. If inheriting from UsingProtoSelector<T>, then one
118*14675a02SAndroid Build Coastguard Worker    * implements MakeDataset(T) instead of MakeDataset(absl::string_view).
119*14675a02SAndroid Build Coastguard Worker    */
120*14675a02SAndroid Build Coastguard Worker   template <typename T>
121*14675a02SAndroid Build Coastguard Worker   class UsingProtoSelector;
122*14675a02SAndroid Build Coastguard Worker };
123*14675a02SAndroid Build Coastguard Worker 
124*14675a02SAndroid Build Coastguard Worker /**
125*14675a02SAndroid Build Coastguard Worker  * HostObjectRegistry for the ExternalDataset interface.
126*14675a02SAndroid Build Coastguard Worker  */
127*14675a02SAndroid Build Coastguard Worker using ExternalDatasetProviderRegistry =
128*14675a02SAndroid Build Coastguard Worker     HostObjectRegistry<ExternalDatasetProvider>;
129*14675a02SAndroid Build Coastguard Worker 
130*14675a02SAndroid Build Coastguard Worker namespace external_dataset_internal {
131*14675a02SAndroid Build Coastguard Worker 
132*14675a02SAndroid Build Coastguard Worker template <typename T>
TryParseProtoSelector(absl::string_view selector)133*14675a02SAndroid Build Coastguard Worker absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) {
134*14675a02SAndroid Build Coastguard Worker   T msg;
135*14675a02SAndroid Build Coastguard Worker   if (!msg.ParseFromArray(selector.data(),
136*14675a02SAndroid Build Coastguard Worker                           CastIntegerChecked<int>(selector.size()))) {
137*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(absl::StrCat(
138*14675a02SAndroid Build Coastguard Worker         "Failed to parse selector proto of type ", msg.GetTypeName()));
139*14675a02SAndroid Build Coastguard Worker   }
140*14675a02SAndroid Build Coastguard Worker 
141*14675a02SAndroid Build Coastguard Worker   return msg;
142*14675a02SAndroid Build Coastguard Worker }
143*14675a02SAndroid Build Coastguard Worker 
144*14675a02SAndroid Build Coastguard Worker template <typename FuncType>
145*14675a02SAndroid Build Coastguard Worker class DatasetFromFunction : public ExternalDataset {
146*14675a02SAndroid Build Coastguard Worker  public:
DatasetFromFunction(FuncType func)147*14675a02SAndroid Build Coastguard Worker   explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {}
148*14675a02SAndroid Build Coastguard Worker 
MakeIterator()149*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<ExternalDatasetIterator> MakeIterator() final {
150*14675a02SAndroid Build Coastguard Worker     return func_();
151*14675a02SAndroid Build Coastguard Worker   }
152*14675a02SAndroid Build Coastguard Worker 
153*14675a02SAndroid Build Coastguard Worker  private:
154*14675a02SAndroid Build Coastguard Worker   FuncType func_;
155*14675a02SAndroid Build Coastguard Worker };
156*14675a02SAndroid Build Coastguard Worker 
157*14675a02SAndroid Build Coastguard Worker }  // namespace external_dataset_internal
158*14675a02SAndroid Build Coastguard Worker 
159*14675a02SAndroid Build Coastguard Worker template <typename T>
160*14675a02SAndroid Build Coastguard Worker class ExternalDatasetProvider::UsingProtoSelector
161*14675a02SAndroid Build Coastguard Worker     : public ExternalDatasetProvider {
162*14675a02SAndroid Build Coastguard Worker  public:
MakeDataset(absl::string_view selector)163*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
164*14675a02SAndroid Build Coastguard Worker       absl::string_view selector) final {
165*14675a02SAndroid Build Coastguard Worker     auto maybe_msg =
166*14675a02SAndroid Build Coastguard Worker         external_dataset_internal::TryParseProtoSelector<T>(selector);
167*14675a02SAndroid Build Coastguard Worker     if (!maybe_msg.ok()) {
168*14675a02SAndroid Build Coastguard Worker       return maybe_msg.status();
169*14675a02SAndroid Build Coastguard Worker     }
170*14675a02SAndroid Build Coastguard Worker 
171*14675a02SAndroid Build Coastguard Worker     return MakeDataset(std::move(maybe_msg).value());
172*14675a02SAndroid Build Coastguard Worker   }
173*14675a02SAndroid Build Coastguard Worker 
174*14675a02SAndroid Build Coastguard Worker   virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
175*14675a02SAndroid Build Coastguard Worker       T selector) = 0;
176*14675a02SAndroid Build Coastguard Worker };
177*14675a02SAndroid Build Coastguard Worker 
178*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
179*14675a02SAndroid Build Coastguard Worker 
180*14675a02SAndroid Build Coastguard Worker #endif  // FCP_TENSORFLOW_EXTERNAL_DATASET_H_
181