xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/external_dataset_op.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "fcp/base/random_token.h"
#include "fcp/tensorflow/external_dataset.h"
#include "fcp/tensorflow/status.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
31 
32 namespace fcp {
33 
34 /**
35  * ExternalDataset op-kernel. Delegates to an ExternalDatasetProvider, found
36  * from the ExternalDatasetProviderRegistry (a HostObjectRegistry).
37  *
38  * Inputs:
39  *   selector: An opaque string scalar. Forwarded to the stub.
40  *   token: String scalar. It should encode a token obtained from
41  *          ExternalDatasetProviderRegistry::Register.
42  *
43  * See TensorFlow's guide to making custom dataset ops:
44  * https://www.tensorflow.org/guide/extend/formats
45  */
46 class ExternalDatasetOp : public tensorflow::data::DatasetOpKernel {
47  public:
48   using tensorflow::data::DatasetOpKernel::DatasetOpKernel;
49 
MakeDataset(tensorflow::OpKernelContext * ctx,tensorflow::data::DatasetBase ** output)50   void MakeDataset(tensorflow::OpKernelContext* ctx,
51                    tensorflow::data::DatasetBase** output) override {
52     tensorflow::tstring token_str;
53     OP_REQUIRES_OK(ctx,
54                    tensorflow::data::ParseScalarArgument<tensorflow::tstring>(
55                        ctx, "token", &token_str));
56     absl::Span<char const> token_bytes = token_str;
57     OP_REQUIRES(ctx, token_bytes.size() == kRandomTokenSizeInBytes,
58                 tensorflow::errors::InvalidArgument(absl::StrFormat(
59                     "Tokens have a fixed size. Expected: %d; Actual %d",
60                     kRandomTokenSizeInBytes, token_bytes.size())));
61     RandomToken token = RandomToken::FromBytes(token_bytes);
62 
63     tensorflow::tstring selector_str;
64     OP_REQUIRES_OK(ctx,
65                    tensorflow::data::ParseScalarArgument<tensorflow::tstring>(
66                        ctx, "selector", &selector_str));
67 
68     std::optional<std::shared_ptr<ExternalDatasetProvider>> maybe_provider =
69         ExternalDatasetProviderRegistry::TryLookup(token);
70     OP_REQUIRES(ctx, maybe_provider.has_value(),
71                 tensorflow::errors::InvalidArgument(
72                     "A dataset provider is not currently registered for the "
73                     "provided token: ",
74                     token.ToPrintableString()));
75 
76     std::shared_ptr<ExternalDatasetProvider> provider =
77         *std::move(maybe_provider);
78     StatusOr<std::unique_ptr<ExternalDataset>> maybe_dataset =
79         provider->MakeDataset(selector_str);
80     // The provider might not like the given selector.
81     if (!maybe_dataset.ok()) {
82       ctx->SetStatus(ConvertToTensorFlowStatus(maybe_dataset.status()));
83       return;
84     }
85 
86     *output = new Dataset(ctx, std::move(maybe_dataset).value());
87   }
88 
89  private:
90   class Dataset : public tensorflow::data::DatasetBase {
91    public:
Dataset(tensorflow::OpKernelContext * ctx,std::unique_ptr<ExternalDataset> stub)92     Dataset(tensorflow::OpKernelContext* ctx,
93             std::unique_ptr<ExternalDataset> stub)
94         : DatasetBase(tensorflow::data::DatasetContext(ctx)),
95           stub_(std::move(stub)) {}
96 
MakeIteratorInternal(const std::string & prefix) const97     std::unique_ptr<tensorflow::data::IteratorBase> MakeIteratorInternal(
98         const std::string& prefix) const override {
99       std::unique_ptr<ExternalDatasetIterator> iter = stub_->MakeIterator();
100       Iterator::Params params{
101           this, tensorflow::strings::StrCat(prefix, "::ExternalDataset")};
102       return std::unique_ptr<tensorflow::data::IteratorBase>(
103           new Iterator(params, std::move(iter)));
104     }
105 
106     // Each iterator element is just a scalar string.
107 
output_dtypes() const108     const tensorflow::DataTypeVector& output_dtypes() const override {
109       static auto* const dtypes =
110           new tensorflow::DataTypeVector({tensorflow::DT_STRING});
111       return *dtypes;
112     }
113 
output_shapes() const114     const std::vector<tensorflow::PartialTensorShape>& output_shapes()
115         const override {
116       static std::vector<tensorflow::PartialTensorShape>* shapes =
117           new std::vector<tensorflow::PartialTensorShape>({{}});
118       return *shapes;
119     }
120 
DebugString() const121     std::string DebugString() const override {
122       return "ExternalDatasetOp::Dataset";
123     }
124 
InputDatasets(std::vector<const DatasetBase * > * inputs) const125     tensorflow::Status InputDatasets(
126         std::vector<const DatasetBase*>* inputs) const override {
127       // ExternalDatast has no input datasets, so just return OK.
128       return tensorflow::OkStatus();
129     }
130 
131 // The `DatasetBase::CheckExternalState()` method was introduced on 8/7/2019. We
132 // use the `TF_GRAPH_DEF_VERSION` value (which is updated daily) to determine if
133 // we should add its override.
134 #if TF_GRAPH_DEF_VERSION > 125
CheckExternalState() const135     tensorflow::Status CheckExternalState() const override {
136       return tensorflow::OkStatus();
137     }
138 #endif
139 
140    protected:
AsGraphDefInternal(tensorflow::data::SerializationContext * ctx,DatasetGraphDefBuilder * b,tensorflow::Node ** output) const141     tensorflow::Status AsGraphDefInternal(
142         tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b,
143         tensorflow::Node** output) const override {
144       return ::tensorflow::errors::Unimplemented(
145           DebugString(), " does not support serialization.");
146     }
147 
148    private:
149     class Iterator : public tensorflow::data::DatasetIterator<Dataset> {
150      public:
Iterator(const Params & params,std::unique_ptr<ExternalDatasetIterator> stub)151       explicit Iterator(const Params& params,
152                         std::unique_ptr<ExternalDatasetIterator> stub)
153           : DatasetIterator<Dataset>(params), stub_(std::move(stub)) {}
154 
GetNextInternal(tensorflow::data::IteratorContext * ctx,std::vector<tensorflow::Tensor> * out_tensors,bool * end_of_sequence)155       tensorflow::Status GetNextInternal(
156           tensorflow::data::IteratorContext* ctx,
157           std::vector<tensorflow::Tensor>* out_tensors,
158           bool* end_of_sequence) override {
159         StatusOr<std::string> maybe_element;
160         {
161           absl::MutexLock _(&mu_);
162           maybe_element = stub_->GetNext();
163         }
164 
165         if (maybe_element.ok()) {
166           std::string element = std::move(maybe_element).value();
167 
168           // The {} at the end specifies a scalar tensor.
169           tensorflow::Tensor element_tensor(ctx->allocator({}),
170                                             tensorflow::DT_STRING, {});
171           element_tensor.scalar<tensorflow::tstring>()() = element;
172 
173           *end_of_sequence = false;
174           out_tensors->push_back(std::move(element_tensor));
175           return tensorflow::OkStatus();
176         } else {
177           *end_of_sequence = true;
178           if (maybe_element.status().code() == StatusCode::kOutOfRange) {
179             return tensorflow::OkStatus();
180           } else {
181             return ConvertToTensorFlowStatus(maybe_element.status());
182           }
183         }
184       }
185 
186      protected:
SaveInternal(tensorflow::data::SerializationContext * ctx,tensorflow::data::IteratorStateWriter * writer)187       tensorflow::Status SaveInternal(
188 // `::tensorflow::data::SerializationContext` argument was added on
189 // 2020-03-17 when `TF_GRAPH_DEF_VERSION` was defined to 343.
190 #if TF_GRAPH_DEF_VERSION > 343
191           tensorflow::data::SerializationContext* ctx,
192 #endif
193           tensorflow::data::IteratorStateWriter* writer) override {
194         return ::tensorflow::errors::Unimplemented(
195             "Save / Restore of an ExternalDataset iterator is not supported");
196       }
RestoreInternal(tensorflow::data::IteratorContext * ctx,tensorflow::data::IteratorStateReader * reader)197       tensorflow::Status RestoreInternal(
198           tensorflow::data::IteratorContext* ctx,
199           tensorflow::data::IteratorStateReader* reader) override {
200         return ::tensorflow::errors::Unimplemented(
201             "Save / Restore of an ExternalDataset iterator is not supported");
202       }
203 
204      private:
205       std::unique_ptr<ExternalDatasetIterator> stub_;
206       absl::Mutex mu_;
207     };
208 
209     // Private members of Dataset
210 
211     std::unique_ptr<ExternalDataset> stub_;
212   };
213 };
214 
// Declares the ExternalDataset op interface. The order of the Input() calls
// defines the op's positional inputs: `token` first, then `selector` (the
// kernel parses both as scalar strings). The single output is the variant
// dataset handle, and its shape function declares it as a scalar.
// NOTE(review): marked stateful, presumably because the kernel's result
// depends on the contents of ExternalDatasetProviderRegistry rather than on
// its inputs alone — confirm.
REGISTER_OP("ExternalDataset")
    .Input("token: string")
    .Input("selector: string")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

// Binds the ExternalDatasetOp kernel to the op on CPU devices only.
REGISTER_KERNEL_BUILDER(Name("ExternalDataset").Device(tensorflow::DEVICE_CPU),
                        ExternalDatasetOp);
224 
225 }  // namespace fcp
226