1 /* 2 * Copyright 2019 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <string> 18 #include <utility> 19 20 #include "absl/strings/str_cat.h" 21 #include "absl/strings/str_format.h" 22 #include "fcp/base/random_token.h" 23 #include "fcp/tensorflow/external_dataset.h" 24 #include "fcp/tensorflow/status.h" 25 #include "tensorflow/core/framework/common_shape_fns.h" 26 #include "tensorflow/core/framework/dataset.h" 27 #include "tensorflow/core/framework/op.h" 28 #include "tensorflow/core/framework/shape_inference.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/public/version.h" 31 32 namespace fcp { 33 34 /** 35 * ExternalDataset op-kernel. Delegates to an ExternalDatasetProvider, found 36 * from the ExternalDatasetProviderRegistry (a HostObjectRegistry). 37 * 38 * Inputs: 39 * selector: An opaque string scalar. Forwarded to the stub. 40 * token: String scalar. It should encode a token obtained from 41 * ExternalDatasetProviderRegistry::Register. 42 * 43 * See TensorFlow's guide to making custom dataset ops: 44 * https://www.tensorflow.org/guide/extend/formats 45 */ 46 class ExternalDatasetOp : public tensorflow::data::DatasetOpKernel { 47 public: 48 using tensorflow::data::DatasetOpKernel::DatasetOpKernel; 49 MakeDataset(tensorflow::OpKernelContext * ctx,tensorflow::data::DatasetBase ** output)50 void MakeDataset(tensorflow::OpKernelContext* ctx, 51 tensorflow::data::DatasetBase** output) override { 52 tensorflow::tstring token_str; 53 OP_REQUIRES_OK(ctx, 54 tensorflow::data::ParseScalarArgument<tensorflow::tstring>( 55 ctx, "token", &token_str)); 56 absl::Span<char const> token_bytes = token_str; 57 OP_REQUIRES(ctx, token_bytes.size() == kRandomTokenSizeInBytes, 58 tensorflow::errors::InvalidArgument(absl::StrFormat( 59 "Tokens have a fixed size. Expected: %d; Actual %d", 60 kRandomTokenSizeInBytes, token_bytes.size()))); 61 RandomToken token = RandomToken::FromBytes(token_bytes); 62 63 tensorflow::tstring selector_str; 64 OP_REQUIRES_OK(ctx, 65 tensorflow::data::ParseScalarArgument<tensorflow::tstring>( 66 ctx, "selector", &selector_str)); 67 68 std::optional<std::shared_ptr<ExternalDatasetProvider>> maybe_provider = 69 ExternalDatasetProviderRegistry::TryLookup(token); 70 OP_REQUIRES(ctx, maybe_provider.has_value(), 71 tensorflow::errors::InvalidArgument( 72 "A dataset provider is not currently registered for the " 73 "provided token: ", 74 token.ToPrintableString())); 75 76 std::shared_ptr<ExternalDatasetProvider> provider = 77 *std::move(maybe_provider); 78 StatusOr<std::unique_ptr<ExternalDataset>> maybe_dataset = 79 provider->MakeDataset(selector_str); 80 // The provider might not like the given selector. 81 if (!maybe_dataset.ok()) { 82 ctx->SetStatus(ConvertToTensorFlowStatus(maybe_dataset.status())); 83 return; 84 } 85 86 *output = new Dataset(ctx, std::move(maybe_dataset).value()); 87 } 88 89 private: 90 class Dataset : public tensorflow::data::DatasetBase { 91 public: Dataset(tensorflow::OpKernelContext * ctx,std::unique_ptr<ExternalDataset> stub)92 Dataset(tensorflow::OpKernelContext* ctx, 93 std::unique_ptr<ExternalDataset> stub) 94 : DatasetBase(tensorflow::data::DatasetContext(ctx)), 95 stub_(std::move(stub)) {} 96 MakeIteratorInternal(const std::string & prefix) const97 std::unique_ptr<tensorflow::data::IteratorBase> MakeIteratorInternal( 98 const std::string& prefix) const override { 99 std::unique_ptr<ExternalDatasetIterator> iter = stub_->MakeIterator(); 100 Iterator::Params params{ 101 this, tensorflow::strings::StrCat(prefix, "::ExternalDataset")}; 102 return std::unique_ptr<tensorflow::data::IteratorBase>( 103 new Iterator(params, std::move(iter))); 104 } 105 106 // Each iterator element is just a scalar string. 107 output_dtypes() const108 const tensorflow::DataTypeVector& output_dtypes() const override { 109 static auto* const dtypes = 110 new tensorflow::DataTypeVector({tensorflow::DT_STRING}); 111 return *dtypes; 112 } 113 output_shapes() const114 const std::vector<tensorflow::PartialTensorShape>& output_shapes() 115 const override { 116 static std::vector<tensorflow::PartialTensorShape>* shapes = 117 new std::vector<tensorflow::PartialTensorShape>({{}}); 118 return *shapes; 119 } 120 DebugString() const121 std::string DebugString() const override { 122 return "ExternalDatasetOp::Dataset"; 123 } 124 InputDatasets(std::vector<const DatasetBase * > * inputs) const125 tensorflow::Status InputDatasets( 126 std::vector<const DatasetBase*>* inputs) const override { 127 // ExternalDatast has no input datasets, so just return OK. 128 return tensorflow::OkStatus(); 129 } 130 131 // The `DatasetBase::CheckExternalState()` method was introduced on 8/7/2019. We 132 // use the `TF_GRAPH_DEF_VERSION` value (which is updated daily) to determine if 133 // we should add its override. 134 #if TF_GRAPH_DEF_VERSION > 125 CheckExternalState() const135 tensorflow::Status CheckExternalState() const override { 136 return tensorflow::OkStatus(); 137 } 138 #endif 139 140 protected: AsGraphDefInternal(tensorflow::data::SerializationContext * ctx,DatasetGraphDefBuilder * b,tensorflow::Node ** output) const141 tensorflow::Status AsGraphDefInternal( 142 tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, 143 tensorflow::Node** output) const override { 144 return ::tensorflow::errors::Unimplemented( 145 DebugString(), " does not support serialization."); 146 } 147 148 private: 149 class Iterator : public tensorflow::data::DatasetIterator<Dataset> { 150 public: Iterator(const Params & params,std::unique_ptr<ExternalDatasetIterator> stub)151 explicit Iterator(const Params& params, 152 std::unique_ptr<ExternalDatasetIterator> stub) 153 : DatasetIterator<Dataset>(params), stub_(std::move(stub)) {} 154 GetNextInternal(tensorflow::data::IteratorContext * ctx,std::vector<tensorflow::Tensor> * out_tensors,bool * end_of_sequence)155 tensorflow::Status GetNextInternal( 156 tensorflow::data::IteratorContext* ctx, 157 std::vector<tensorflow::Tensor>* out_tensors, 158 bool* end_of_sequence) override { 159 StatusOr<std::string> maybe_element; 160 { 161 absl::MutexLock _(&mu_); 162 maybe_element = stub_->GetNext(); 163 } 164 165 if (maybe_element.ok()) { 166 std::string element = std::move(maybe_element).value(); 167 168 // The {} at the end specifies a scalar tensor. 169 tensorflow::Tensor element_tensor(ctx->allocator({}), 170 tensorflow::DT_STRING, {}); 171 element_tensor.scalar<tensorflow::tstring>()() = element; 172 173 *end_of_sequence = false; 174 out_tensors->push_back(std::move(element_tensor)); 175 return tensorflow::OkStatus(); 176 } else { 177 *end_of_sequence = true; 178 if (maybe_element.status().code() == StatusCode::kOutOfRange) { 179 return tensorflow::OkStatus(); 180 } else { 181 return ConvertToTensorFlowStatus(maybe_element.status()); 182 } 183 } 184 } 185 186 protected: SaveInternal(tensorflow::data::SerializationContext * ctx,tensorflow::data::IteratorStateWriter * writer)187 tensorflow::Status SaveInternal( 188 // `::tensorflow::data::SerializationContext` argument was added on 189 // 2020-03-17 when `TF_GRAPH_DEF_VERSION` was defined to 343. 190 #if TF_GRAPH_DEF_VERSION > 343 191 tensorflow::data::SerializationContext* ctx, 192 #endif 193 tensorflow::data::IteratorStateWriter* writer) override { 194 return ::tensorflow::errors::Unimplemented( 195 "Save / Restore of an ExternalDataset iterator is not supported"); 196 } RestoreInternal(tensorflow::data::IteratorContext * ctx,tensorflow::data::IteratorStateReader * reader)197 tensorflow::Status RestoreInternal( 198 tensorflow::data::IteratorContext* ctx, 199 tensorflow::data::IteratorStateReader* reader) override { 200 return ::tensorflow::errors::Unimplemented( 201 "Save / Restore of an ExternalDataset iterator is not supported"); 202 } 203 204 private: 205 std::unique_ptr<ExternalDatasetIterator> stub_; 206 absl::Mutex mu_; 207 }; 208 209 // Private members of Dataset 210 211 std::unique_ptr<ExternalDataset> stub_; 212 }; 213 }; 214 215 REGISTER_OP("ExternalDataset") 216 .Input("token: string") 217 .Input("selector: string") 218 .Output("handle: variant") 219 .SetIsStateful() 220 .SetShapeFn(tensorflow::shape_inference::ScalarShape); 221 222 REGISTER_KERNEL_BUILDER(Name("ExternalDataset").Device(tensorflow::DEVICE_CPU), 223 ExternalDatasetOp); 224 225 } // namespace fcp 226