/* * Copyright 2019 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FCP_TENSORFLOW_EXTERNAL_DATASET_H_ #define FCP_TENSORFLOW_EXTERNAL_DATASET_H_ #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "fcp/base/bounds.h" #include "fcp/tensorflow/host_object.h" namespace fcp { /** * Interface for an iterator, created from a particular dataset. A single * dataset may be used to create multiple iterators. */ class ExternalDatasetIterator { public: virtual ~ExternalDatasetIterator() = default; /** * Returns the next element, if possible. Indicates end-of-stream with * OUT_OF_RANGE, even when repeatedly called. Corresponds to * tensorflow::data::IteratorBase::GetNext. * * Implementations must be thread-safe. */ virtual absl::StatusOr GetNext() = 0; }; namespace external_dataset_internal { template class DatasetFromFunction; } // namespace external_dataset_internal /** * Interface for a particular dataset - created from an ExternalDatasetProvider * (during dataset op execution), for a particular selector. A dataset may be * used zero or more times to create an ExternalDatasetIterator. * * Dataset implementations are often trivial, just needing to capture some * values (like the selector) for the iterator constructor. Consider using * ExternalDataset::FromFunction. */ class ExternalDataset { public: virtual ~ExternalDataset() = default; /** * Creates a new iterator. Corresponds to * tensorflow::data::DatasetBase::MakeIterator. */ virtual std::unique_ptr MakeIterator() = 0; /** * Creates an ExternalDataset that wraps a callable object 'f', implementing * MakeIterator(). The lifetime of 'f' is that of the dataset (so, * by-reference lambda captures are almost always unsafe here). */ template static std::unique_ptr FromFunction(F f) { return std::make_unique>( std::move(f)); } }; /** * Interface for an ExternalDataset op's host object. * * An ExternalDatasetProvider is a function from Selector -> ExternalDataset. * Here, 'Selector' is a string provided to the dataset op (typically, an * encoded proto). The returned ExternalDataset may be used (perhaps multiple * times) to create an iterator. * * When implementing a dataset provider and the selector is a proto message, * consider inheritng from ExternalDatasetProvider::UsingProtoSelector (for * some message type T). */ class ExternalDatasetProvider { public: virtual ~ExternalDatasetProvider() = default; /** * Creates a dataset for a given selector. * * This function can usually be implemented succinctly, using * ExternalDataset::FromFunction. * * Corresponds to tensorflow::data::DatasetOpKernel::MakeDataset. */ virtual absl::StatusOr> MakeDataset( absl::string_view selector) = 0; /** * Base class for dataset providers that expect a selector of a particular * proto message type. If inheriting from UsingProtoSelector, then one * implements MakeDataset(T) instead of MakeDataset(absl::string_view). */ template class UsingProtoSelector; }; /** * HostObjectRegistry for the ExternalDataset interface. */ using ExternalDatasetProviderRegistry = HostObjectRegistry; namespace external_dataset_internal { template absl::StatusOr TryParseProtoSelector(absl::string_view selector) { T msg; if (!msg.ParseFromArray(selector.data(), CastIntegerChecked(selector.size()))) { return absl::InvalidArgumentError(absl::StrCat( "Failed to parse selector proto of type ", msg.GetTypeName())); } return msg; } template class DatasetFromFunction : public ExternalDataset { public: explicit DatasetFromFunction(FuncType func) : func_(std::move(func)) {} std::unique_ptr MakeIterator() final { return func_(); } private: FuncType func_; }; } // namespace external_dataset_internal template class ExternalDatasetProvider::UsingProtoSelector : public ExternalDatasetProvider { public: absl::StatusOr> MakeDataset( absl::string_view selector) final { auto maybe_msg = external_dataset_internal::TryParseProtoSelector(selector); if (!maybe_msg.ok()) { return maybe_msg.status(); } return MakeDataset(std::move(maybe_msg).value()); } virtual absl::StatusOr> MakeDataset( T selector) = 0; }; } // namespace fcp #endif // FCP_TENSORFLOW_EXTERNAL_DATASET_H_