1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_
18 #define FCP_TENSORFLOW_EXTERNAL_DATASET_H_
19
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "fcp/base/bounds.h"
#include "fcp/tensorflow/host_object.h"
29
30 namespace fcp {
31
32 /**
33 * Interface for an iterator, created from a particular dataset. A single
34 * dataset may be used to create multiple iterators.
35 */
36 class ExternalDatasetIterator {
37 public:
38 virtual ~ExternalDatasetIterator() = default;
39
40 /**
41 * Returns the next element, if possible. Indicates end-of-stream with
42 * OUT_OF_RANGE, even when repeatedly called. Corresponds to
43 * tensorflow::data::IteratorBase::GetNext.
44 *
45 * Implementations must be thread-safe.
46 */
47 virtual absl::StatusOr<std::string> GetNext() = 0;
48 };
49
namespace external_dataset_internal {

// Declared ahead of ExternalDataset so that ExternalDataset::FromFunction can
// refer to it; the full definition appears further down in this header.
template <class FuncType>
class DatasetFromFunction;

}  // namespace external_dataset_internal
56
57 /**
58 * Interface for a particular dataset - created from an ExternalDatasetProvider
59 * (during dataset op execution), for a particular selector. A dataset may be
60 * used zero or more times to create an ExternalDatasetIterator.
61 *
62 * Dataset implementations are often trivial, just needing to capture some
63 * values (like the selector) for the iterator constructor. Consider using
64 * ExternalDataset::FromFunction.
65 */
66 class ExternalDataset {
67 public:
68 virtual ~ExternalDataset() = default;
69
70 /**
71 * Creates a new iterator. Corresponds to
72 * tensorflow::data::DatasetBase::MakeIterator.
73 */
74 virtual std::unique_ptr<ExternalDatasetIterator> MakeIterator() = 0;
75
76 /**
77 * Creates an ExternalDataset that wraps a callable object 'f', implementing
78 * MakeIterator(). The lifetime of 'f' is that of the dataset (so,
79 * by-reference lambda captures are almost always unsafe here).
80 */
81 template <typename F>
FromFunction(F f)82 static std::unique_ptr<ExternalDataset> FromFunction(F f) {
83 return std::make_unique<external_dataset_internal::DatasetFromFunction<F>>(
84 std::move(f));
85 }
86 };
87
88 /**
89 * Interface for an ExternalDataset op's host object.
90 *
91 * An ExternalDatasetProvider is a function from Selector -> ExternalDataset.
92 * Here, 'Selector' is a string provided to the dataset op (typically, an
93 * encoded proto). The returned ExternalDataset may be used (perhaps multiple
94 * times) to create an iterator.
95 *
96 * When implementing a dataset provider and the selector is a proto message,
 * consider inheriting from ExternalDatasetProvider::UsingProtoSelector<T> (for
98 * some message type T).
99 */
100 class ExternalDatasetProvider {
101 public:
102 virtual ~ExternalDatasetProvider() = default;
103
104 /**
105 * Creates a dataset for a given selector.
106 *
107 * This function can usually be implemented succinctly, using
108 * ExternalDataset::FromFunction.
109 *
110 * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset.
111 */
112 virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
113 absl::string_view selector) = 0;
114
115 /**
116 * Base class for dataset providers that expect a selector of a particular
117 * proto message type. If inheriting from UsingProtoSelector<T>, then one
118 * implements MakeDataset(T) instead of MakeDataset(absl::string_view).
119 */
120 template <typename T>
121 class UsingProtoSelector;
122 };
123
124 /**
125 * HostObjectRegistry for the ExternalDataset interface.
126 */
127 using ExternalDatasetProviderRegistry =
128 HostObjectRegistry<ExternalDatasetProvider>;
129
130 namespace external_dataset_internal {
131
132 template <typename T>
TryParseProtoSelector(absl::string_view selector)133 absl::StatusOr<T> TryParseProtoSelector(absl::string_view selector) {
134 T msg;
135 if (!msg.ParseFromArray(selector.data(),
136 CastIntegerChecked<int>(selector.size()))) {
137 return absl::InvalidArgumentError(absl::StrCat(
138 "Failed to parse selector proto of type ", msg.GetTypeName()));
139 }
140
141 return msg;
142 }
143
144 template <typename FuncType>
145 class DatasetFromFunction : public ExternalDataset {
146 public:
DatasetFromFunction(FuncType func)147 explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {}
148
MakeIterator()149 std::unique_ptr<ExternalDatasetIterator> MakeIterator() final {
150 return func_();
151 }
152
153 private:
154 FuncType func_;
155 };
156
157 } // namespace external_dataset_internal
158
159 template <typename T>
160 class ExternalDatasetProvider::UsingProtoSelector
161 : public ExternalDatasetProvider {
162 public:
MakeDataset(absl::string_view selector)163 absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
164 absl::string_view selector) final {
165 auto maybe_msg =
166 external_dataset_internal::TryParseProtoSelector<T>(selector);
167 if (!maybe_msg.ok()) {
168 return maybe_msg.status();
169 }
170
171 return MakeDataset(std::move(maybe_msg).value());
172 }
173
174 virtual absl::StatusOr<std::unique_ptr<ExternalDataset>> MakeDataset(
175 T selector) = 0;
176 };
177
178 } // namespace fcp
179
180 #endif // FCP_TENSORFLOW_EXTERNAL_DATASET_H_
181