1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/list_dataset_op.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/core/data/dataset_utils.h"
23 #include "tensorflow/core/data/name_utils.h"
24 #include "tensorflow/core/data/split_utils.h"
25 #include "tensorflow/core/framework/partial_tensor_shape.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/util/batch_util.h"
30
31 namespace tensorflow {
32 namespace data {
33
34 // See documentation in ../../ops/dataset_ops.cc for a high-level
35 // description of the following op.
36
// Out-of-line definitions for the `static constexpr` string constants of
// `ListDatasetOp`. No initializers appear here, so the values presumably live
// with the in-class declarations in list_dataset_op.h. Before C++17 such a
// namespace-scope definition is required whenever one of these constants is
// odr-used (for example, bound to a const reference).
/* static */ constexpr const char* const ListDatasetOp::kDatasetType;
/* static */ constexpr const char* const ListDatasetOp::kTensors;
/* static */ constexpr const char* const ListDatasetOp::kTinputTypes;
/* static */ constexpr const char* const ListDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ListDatasetOp::kOutputShapes;
42
43 class ListDatasetOp::Dataset : public DatasetBase {
44 public:
Dataset(OpKernelContext * ctx,std::vector<Tensor> tensors,const DataTypeVector & input_types,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes,int num_components)45 Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors,
46 const DataTypeVector& input_types, const DataTypeVector& output_types,
47 const std::vector<PartialTensorShape>& output_shapes,
48 int num_components)
49 : DatasetBase(DatasetContext(ctx)),
50 tensors_(std::move(tensors)),
51 num_elements_(tensors_.size() / num_components),
52 num_components_(num_components),
53 input_types_(input_types),
54 output_types_(output_types),
55 output_shapes_(output_shapes) {}
56
MakeIteratorInternal(const string & prefix) const57 std::unique_ptr<IteratorBase> MakeIteratorInternal(
58 const string& prefix) const override {
59 return std::make_unique<Iterator>(Iterator::Params{
60 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
61 }
62
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * split_providers) const63 Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
64 split_providers) const override {
65 split_providers->push_back(
66 std::make_unique<IndexSplitProvider>(num_elements_));
67 return Status::OK();
68 }
69
output_dtypes() const70 const DataTypeVector& output_dtypes() const override { return output_types_; }
71
output_shapes() const72 const std::vector<PartialTensorShape>& output_shapes() const override {
73 return output_shapes_;
74 }
75
DebugString() const76 string DebugString() const override {
77 return name_utils::DatasetDebugString(kDatasetType);
78 }
79
CardinalityInternal() const80 int64_t CardinalityInternal() const override { return num_elements_; }
81
CardinalityInternal(CardinalityOptions options) const82 int64_t CardinalityInternal(CardinalityOptions options) const override {
83 return num_elements_;
84 }
85
InputDatasets(std::vector<const DatasetBase * > * inputs) const86 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
87 return Status::OK();
88 }
89
CheckExternalState() const90 Status CheckExternalState() const override { return Status::OK(); }
91
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const92 Status Get(OpKernelContext* ctx, int64 index,
93 std::vector<Tensor>* out_tensors) const override {
94 TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
95 out_tensors->clear();
96 out_tensors->reserve(num_components_);
97 for (int i = 0; i < num_components_; ++i) {
98 out_tensors->push_back(tensors_[i + num_components_ * index]);
99 }
100 return Status::OK();
101 }
102
103 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const104 Status AsGraphDefInternal(SerializationContext* ctx,
105 DatasetGraphDefBuilder* b,
106 Node** output) const override {
107 std::vector<Node*> tensors;
108 tensors.reserve(tensors_.size());
109 for (const Tensor& t : tensors_) {
110 Node* node;
111 if (!ctx->is_graph_rewrite()) {
112 TF_RETURN_IF_ERROR(b->AddDatasetOrTensor(ctx, t, &node));
113 } else {
114 TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
115 DCHECK_NE(ctx->input_list(), nullptr);
116 ctx->input_list()->emplace_back(node->name(), t);
117 }
118 tensors.emplace_back(node);
119 }
120 AttrValue input_types;
121 b->BuildAttrValue(input_types_, &input_types);
122 TF_RETURN_IF_ERROR(b->AddDataset(this, {}, {{0, tensors}},
123 {{kTinputTypes, input_types}}, output));
124 return Status::OK();
125 }
126
127 private:
128 class Iterator : public DatasetIterator<Dataset> {
129 public:
Iterator(const Params & params)130 explicit Iterator(const Params& params)
131 : DatasetIterator<Dataset>(params) {}
132
Initialize(IteratorContext * ctx)133 Status Initialize(IteratorContext* ctx) override {
134 if (ctx->split_providers().empty()) {
135 split_provider_ =
136 std::make_shared<IndexSplitProvider>(dataset()->num_elements_);
137 } else {
138 TF_ASSIGN_OR_RETURN(split_provider_,
139 GetSingleSplitProvider(ctx, dataset()));
140 }
141 return Status::OK();
142 }
143
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)144 Status GetNextInternal(IteratorContext* ctx,
145 std::vector<Tensor>* out_tensors,
146 bool* end_of_sequence) override {
147 Tensor split;
148 TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
149 if (*end_of_sequence) {
150 return Status::OK();
151 }
152 int64_t index = split.scalar<int64_t>()();
153 out_tensors->reserve(dataset()->num_components_);
154 for (size_t i = 0; i < dataset()->num_components_; ++i) {
155 out_tensors->push_back(
156 dataset()->tensors_[i + dataset()->num_components_ * index]);
157 }
158 *end_of_sequence = false;
159 return Status::OK();
160 }
161
162 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const163 std::shared_ptr<model::Node> CreateNode(
164 IteratorContext* ctx, model::Node::Args args) const override {
165 return model::MakeSourceNode(std::move(args));
166 }
167
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)168 Status SaveInternal(SerializationContext* ctx,
169 IteratorStateWriter* writer) override {
170 return split_provider_->Save(
171 [this](const std::string& key) { return full_name(key); }, writer);
172 }
173
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)174 Status RestoreInternal(IteratorContext* ctx,
175 IteratorStateReader* reader) override {
176 return split_provider_->Restore(
177 [this](const std::string& key) { return full_name(key); }, reader);
178 }
179
180 private:
181 std::shared_ptr<SplitProvider> split_provider_;
182 };
183
184 const std::vector<Tensor> tensors_;
185 int64 num_elements_;
186 size_t num_components_;
187 DataTypeVector input_types_;
188 DataTypeVector output_types_;
189 std::vector<PartialTensorShape> output_shapes_;
190 };
191
ListDatasetOp(OpKernelConstruction * ctx)192 ListDatasetOp::ListDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
193 OP_REQUIRES_OK(ctx, ctx->GetAttr(kTinputTypes, &input_types_));
194 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
195 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
196 }
197
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)198 void ListDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
199 OpInputList inputs;
200 OP_REQUIRES_OK(ctx, ctx->input_list(kTensors, &inputs));
201 std::vector<Tensor> tensors(inputs.begin(), inputs.end());
202 *output = new Dataset(ctx, std::move(tensors), input_types_, output_types_,
203 output_shapes_, output_shapes_.size());
204 OP_REQUIRES_OK(ctx,
205 VerifyTypesMatch((*output)->output_dtypes(), output_types_));
206 OP_REQUIRES_OK(
207 ctx, VerifyShapesCompatible((*output)->output_shapes(), output_shapes_));
208 }
209
namespace {

// Registers `ListDatasetOp` as the kernel for the "ListDataset" op on
// `DEVICE_CPU`. The registration sits in an anonymous namespace, which gives
// whatever the macro declares internal linkage.
REGISTER_KERNEL_BUILDER(Name("ListDataset").Device(DEVICE_CPU), ListDatasetOp);
}  // namespace
214 } // namespace data
215 } // namespace tensorflow
216