xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/external_dataset.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_
18 #define FCP_TENSORFLOW_EXTERNAL_DATASET_H_
19 
20 #include <memory>
21 #include <string>
22 
23 #include "absl/status/status.h"
24 #include "absl/status/statusor.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/string_view.h"
27 #include "fcp/base/bounds.h"
28 #include "fcp/tensorflow/host_object.h"
29 
30 namespace fcp {
31 
32 /**
33  * Interface for an iterator, created from a particular dataset. A single
34  * dataset may be used to create multiple iterators.
35  */
36 class ExternalDatasetIterator {
37  public:
38   virtual ~ExternalDatasetIterator() = default;
39 
40   /**
41    * Returns the next element, if possible. Indicates end-of-stream with
42    * OUT_OF_RANGE, even when repeatedly called. Corresponds to
43    * tensorflow::data::IteratorBase::GetNext.
44    *
45    * Implementations must be thread-safe.
46    */
47   virtual absl::StatusOr<std::string> GetNext() = 0;
48 };
49 
namespace external_dataset_internal {

// Declared ahead of ExternalDataset so that ExternalDataset::FromFunction can
// name this adapter; its definition appears further down in this header.
template <typename FuncType>
class DatasetFromFunction;

}  // namespace external_dataset_internal
56 
57 /**
58  * Interface for a particular dataset - created from an ExternalDatasetProvider
59  * (during dataset op execution), for a particular selector. A dataset may be
60  * used zero or more times to create an ExternalDatasetIterator.
61  *
62  * Dataset implementations are often trivial, just needing to capture some
63  * values (like the selector) for the iterator constructor. Consider using
64  * ExternalDataset::FromFunction.
65  */
66 class ExternalDataset {
67  public:
68   virtual ~ExternalDataset() = default;
69 
70   /**
71    * Creates a new iterator. Corresponds to
72    * tensorflow::data::DatasetBase::MakeIterator.
73    */
74   virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0;
75 
76   /**
77    * Creates an ExternalDataset that wraps a callable object 'f', implementing
78    * MakeIterator(). The lifetime of 'f' is that of the dataset (so,
79    * by-reference lambda captures are almost always unsafe here).
80    */
81   template <typename F>
FromFunction(F f)82   static std::unique_ptr<ExternalDataset> FromFunction(F f) {
83     return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>(
84         std::move(f));
85   }
86 };
87 
88 /**
89  * Interface for an ExternalDataset op's host object.
90  *
91  * An ExternalDatasetProvider is a function from Selector -> ExternalDataset.
92  * Here, 'Selector' is a string provided to the dataset op (typically, an
93  * encoded proto). The returned ExternalDataset may be used (perhaps multiple
94  * times) to create an iterator.
95  *
96  * When implementing a dataset provider and the selector is a proto message,
97  * consider inheritng from ExternalDatasetProvider::UsingProtoSelector<T> (for
98  * some message type T).
99  */
100 class ExternalDatasetProvider {
101  public:
102   virtual ~ExternalDatasetProvider() = default;
103 
104   /**
105    * Creates a dataset for a given selector.
106    *
107    * This function can usually be implemented succinctly, using
108    * ExternalDataset::FromFunction.
109    *
110    * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset.
111    */
112   virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
113       absl::string_view selector) = 0;
114 
115   /**
116    * Base class for dataset providers that expect a selector of a particular
117    * proto message type. If inheriting from UsingProtoSelector<T>, then one
118    * implements MakeDataset(T) instead of MakeDataset(absl::string_view).
119    */
120   template <typename T>
121   class UsingProtoSelector;
122 };
123 
124 /**
125  * HostObjectRegistry for the ExternalDataset interface.
126  */
127 using ExternalDatasetProviderRegistry =
128     HostObjectRegistry<ExternalDatasetProvider>;
129 
130 namespace external_dataset_internal {
131 
132 template <typename T>
TryParseProtoSelector(absl::string_view selector)133 absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) {
134   T msg;
135   if (!msg.ParseFromArray(selector.data(),
136                           CastIntegerChecked<int>(selector.size()))) {
137     return absl::InvalidArgumentError(absl::StrCat(
138         "Failed to parse selector proto of type ", msg.GetTypeName()));
139   }
140 
141   return msg;
142 }
143 
144 template <typename FuncType>
145 class DatasetFromFunction : public ExternalDataset {
146  public:
DatasetFromFunction(FuncType func)147   explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {}
148 
MakeIterator()149   std::unique_ptr<ExternalDatasetIterator> MakeIterator() final {
150     return func_();
151   }
152 
153  private:
154   FuncType func_;
155 };
156 
157 }  // namespace external_dataset_internal
158 
159 template <typename T>
160 class ExternalDatasetProvider::UsingProtoSelector
161     : public ExternalDatasetProvider {
162  public:
MakeDataset(absl::string_view selector)163   absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
164       absl::string_view selector) final {
165     auto maybe_msg =
166         external_dataset_internal::TryParseProtoSelector<T>(selector);
167     if (!maybe_msg.ok()) {
168       return maybe_msg.status();
169     }
170 
171     return MakeDataset(std::move(maybe_msg).value());
172   }
173 
174   virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
175       T selector) = 0;
176 };
177 
178 }  // namespace fcp
179 
180 #endif  // FCP_TENSORFLOW_EXTERNAL_DATASET_H_
181