xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/experimental/list_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/list_dataset_op.h"
16 
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/name_utils.h"
24 #include "tensorflow/core/data/split_utils.h"
25 #include "tensorflow/core/framework/partial_tensor_shape.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/util/batch_util.h"
30 
31 namespace tensorflow {
32 namespace data {
33 
34 // See documentation in ../../ops/dataset_ops.cc for a high-level
35 // description of the following op.
36 
// Out-of-line definitions of ListDatasetOp's static constexpr string
// constants. The declarations (and their values) are not in this file;
// presumably they live in list_dataset_op.h, included above. Prior to C++17 a
// namespace-scope definition like these is needed for any static constexpr
// data member that is ODR-used.
/* static */ constexpr const char* const ListDatasetOp::kDatasetType;
/* static */ constexpr const char* const ListDatasetOp::kTensors;
/* static */ constexpr const char* const ListDatasetOp::kTinputTypes;
/* static */ constexpr const char* const ListDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ListDatasetOp::kOutputShapes;
42 
43 class ListDatasetOp::Dataset : public DatasetBase {
44  public:
Dataset(OpKernelContext * ctx,std::vector<Tensor> tensors,const DataTypeVector & input_types,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes,int num_components)45   Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors,
46           const DataTypeVector& input_types, const DataTypeVector& output_types,
47           const std::vector<PartialTensorShape>& output_shapes,
48           int num_components)
49       : DatasetBase(DatasetContext(ctx)),
50         tensors_(std::move(tensors)),
51         num_elements_(tensors_.size() / num_components),
52         num_components_(num_components),
53         input_types_(input_types),
54         output_types_(output_types),
55         output_shapes_(output_shapes) {}
56 
MakeIteratorInternal(const string & prefix) const57   std::unique_ptr<IteratorBase> MakeIteratorInternal(
58       const string& prefix) const override {
59     return std::make_unique<Iterator>(Iterator::Params{
60         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
61   }
62 
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * split_providers) const63   Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
64                                 split_providers) const override {
65     split_providers->push_back(
66         std::make_unique<IndexSplitProvider>(num_elements_));
67     return Status::OK();
68   }
69 
output_dtypes() const70   const DataTypeVector& output_dtypes() const override { return output_types_; }
71 
output_shapes() const72   const std::vector<PartialTensorShape>& output_shapes() const override {
73     return output_shapes_;
74   }
75 
DebugString() const76   string DebugString() const override {
77     return name_utils::DatasetDebugString(kDatasetType);
78   }
79 
CardinalityInternal() const80   int64_t CardinalityInternal() const override { return num_elements_; }
81 
CardinalityInternal(CardinalityOptions options) const82   int64_t CardinalityInternal(CardinalityOptions options) const override {
83     return num_elements_;
84   }
85 
InputDatasets(std::vector<const DatasetBase * > * inputs) const86   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
87     return Status::OK();
88   }
89 
CheckExternalState() const90   Status CheckExternalState() const override { return Status::OK(); }
91 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const92   Status Get(OpKernelContext* ctx, int64 index,
93              std::vector<Tensor>* out_tensors) const override {
94     TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
95     out_tensors->clear();
96     out_tensors->reserve(num_components_);
97     for (int i = 0; i < num_components_; ++i) {
98       out_tensors->push_back(tensors_[i + num_components_ * index]);
99     }
100     return Status::OK();
101   }
102 
103  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const104   Status AsGraphDefInternal(SerializationContext* ctx,
105                             DatasetGraphDefBuilder* b,
106                             Node** output) const override {
107     std::vector<Node*> tensors;
108     tensors.reserve(tensors_.size());
109     for (const Tensor& t : tensors_) {
110       Node* node;
111       if (!ctx->is_graph_rewrite()) {
112         TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
113       } else {
114         TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
115         DCHECK_NE(ctx->input_list(), nullptr);
116         ctx->input_list()->emplace_back(node->name(), t);
117       }
118       tensors.emplace_back(node);
119     }
120     AttrValue input_types;
121     b->BuildAttrValue(input_types_, &input_types);
122     TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, tensors}},
123                                      {{kTinputTypes, input_types}}, output));
124     return Status::OK();
125   }
126 
127  private:
128   class Iterator : public DatasetIterator<Dataset> {
129    public:
Iterator(const Params & params)130     explicit Iterator(const Params& params)
131         : DatasetIterator<Dataset>(params) {}
132 
Initialize(IteratorContext * ctx)133     Status Initialize(IteratorContext* ctx) override {
134       if (ctx->split_providers().empty()) {
135         split_provider_ =
136             std::make_shared<IndexSplitProvider>(dataset()->num_elements_);
137       } else {
138         TF_ASSIGN_OR_RETURN(split_provider_,
139                             GetSingleSplitProvider(ctx, dataset()));
140       }
141       return Status::OK();
142     }
143 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)144     Status GetNextInternal(IteratorContext* ctx,
145                            std::vector<Tensor>* out_tensors,
146                            bool* end_of_sequence) override {
147       Tensor split;
148       TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
149       if (*end_of_sequence) {
150         return Status::OK();
151       }
152       int64_t index = split.scalar<int64_t>()();
153       out_tensors->reserve(dataset()->num_components_);
154       for (size_t i = 0; i < dataset()->num_components_; ++i) {
155         out_tensors->push_back(
156             dataset()->tensors_[i + dataset()->num_components_ * index]);
157       }
158       *end_of_sequence = false;
159       return Status::OK();
160     }
161 
162    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const163     std::shared_ptr<model::Node> CreateNode(
164         IteratorContext* ctx, model::Node::Args args) const override {
165       return model::MakeSourceNode(std::move(args));
166     }
167 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)168     Status SaveInternal(SerializationContext* ctx,
169                         IteratorStateWriter* writer) override {
170       return split_provider_->Save(
171           [this](const std::string& key) { return full_name(key); }, writer);
172     }
173 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)174     Status RestoreInternal(IteratorContext* ctx,
175                            IteratorStateReader* reader) override {
176       return split_provider_->Restore(
177           [this](const std::string& key) { return full_name(key); }, reader);
178     }
179 
180    private:
181     std::shared_ptr<SplitProvider> split_provider_;
182   };
183 
184   const std::vector<Tensor> tensors_;
185   int64 num_elements_;
186   size_t num_components_;
187   DataTypeVector input_types_;
188   DataTypeVector output_types_;
189   std::vector<PartialTensorShape> output_shapes_;
190 };
191 
// Kernel constructor: reads the `kTinputTypes`, `kOutputTypes` and
// `kOutputShapes` attributes into the corresponding member fields. Each lookup
// is wrapped in OP_REQUIRES_OK, so the attributes are fetched in this order
// and (presumably, per that macro's usual contract) the first failure is
// reported on `ctx` and ends construction early.
ListDatasetOp::ListDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kTinputTypes, &input_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
197 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)198 void ListDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
199   OpInputList inputs;
200   OP_REQUIRES_OK(ctx, ctx->input_list(kTensors, &inputs));
201   std::vector<Tensor> tensors(inputs.begin(), inputs.end());
202   *output = new Dataset(ctx, std::move(tensors), input_types_, output_types_,
203                         output_shapes_, output_shapes_.size());
204   OP_REQUIRES_OK(ctx,
205                  VerifyTypesMatch((*output)->output_dtypes(), output_types_));
206   OP_REQUIRES_OK(
207       ctx, VerifyShapesCompatible((*output)->output_shapes(), output_shapes_));
208 }
209 
namespace {

// Registers ListDatasetOp as the kernel for the op named "ListDataset" on
// DEVICE_CPU.
REGISTER_KERNEL_BUILDER(Name("ListDataset").Device(DEVICE_CPU), ListDatasetOp);
}  // namespace
214 }  // namespace data
215 }  // namespace tensorflow
216