/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include "tensorflow/core/kernels/data/take_dataset_op.h"
16
17 #include "tensorflow/core/data/name_utils.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21
22 namespace tensorflow {
23 namespace data {
24
25 /* static */ constexpr const char* const TakeDatasetOp::kDatasetType;
26 /* static */ constexpr const char* const TakeDatasetOp::kInputDataset;
27 /* static */ constexpr const char* const TakeDatasetOp::kCount;
28 /* static */ constexpr const char* const TakeDatasetOp::kOutputTypes;
29 /* static */ constexpr const char* const TakeDatasetOp::kOutputShapes;
30
31 constexpr char kCurIndex[] = "i";
32 constexpr char kInputImplEmpty[] = "input_impl_empty";
33 constexpr char kEmptyTake[] = "EmptyTake";
34 constexpr char kFiniteTake[] = "FiniteTake";
35
TakeDataset(OpKernelContext * ctx,int64_t count,const DatasetBase * input)36 TakeDataset::TakeDataset(OpKernelContext* ctx, int64_t count,
37 const DatasetBase* input)
38 : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
39 input_->Ref();
40 }
41
TakeDataset(DatasetContext::Params params,int64_t count,const DatasetBase * input)42 TakeDataset::TakeDataset(DatasetContext::Params params, int64_t count,
43 const DatasetBase* input)
44 : DatasetBase(DatasetContext(std::move(params))),
45 count_(count),
46 input_(input) {
47 input_->Ref();
48 }
49
~TakeDataset()50 TakeDataset::~TakeDataset() { input_->Unref(); }
51
output_dtypes() const52 const DataTypeVector& TakeDataset::output_dtypes() const {
53 return input_->output_dtypes();
54 }
55
output_shapes() const56 const std::vector<PartialTensorShape>& TakeDataset::output_shapes() const {
57 return input_->output_shapes();
58 }
59
DebugString() const60 string TakeDataset::DebugString() const {
61 return name_utils::DatasetDebugString(TakeDatasetOp::kDatasetType);
62 }
63
CardinalityInternal() const64 int64_t TakeDataset::CardinalityInternal() const {
65 int64_t n = input_->Cardinality();
66 if (n == kUnknownCardinality) {
67 return kUnknownCardinality;
68 }
69 if (n == kInfiniteCardinality) {
70 return count_;
71 } else if (count_ == kInfiniteCardinality) {
72 return n;
73 }
74 return std::min(n, count_);
75 }
76
CardinalityInternal(CardinalityOptions options) const77 int64_t TakeDataset::CardinalityInternal(CardinalityOptions options) const {
78 int64_t n = input_->Cardinality(options);
79 if (n == kUnknownCardinality) {
80 return kUnknownCardinality;
81 }
82 if (n == kInfiniteCardinality) {
83 return count_;
84 } else if (count_ == kInfiniteCardinality) {
85 return n;
86 }
87
88 return std::min(n, count_);
89 }
90
InputDatasets(std::vector<const DatasetBase * > * inputs) const91 Status TakeDataset::InputDatasets(
92 std::vector<const DatasetBase*>* inputs) const {
93 inputs->push_back(input_);
94 return OkStatus();
95 }
96
CheckExternalState() const97 Status TakeDataset::CheckExternalState() const {
98 return input_->CheckExternalState();
99 }
100
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const101 Status TakeDataset::Get(OpKernelContext* ctx, int64 index,
102 std::vector<Tensor>* out_tensors) const {
103 TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
104 return input_->Get(ctx, index, out_tensors);
105 }
106
107 class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
108 public:
EmptyIterator(const Params & params)109 explicit EmptyIterator(const Params& params)
110 : DatasetIterator<TakeDataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)111 Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
112 bool* end_of_sequence) override {
113 *end_of_sequence = true;
114 return OkStatus();
115 }
116
117 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const118 std::shared_ptr<model::Node> CreateNode(
119 IteratorContext* ctx, model::Node::Args args) const override {
120 return model::MakeKnownRatioNode(std::move(args),
121 /*ratio=*/1);
122 }
123
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)124 Status SaveInternal(SerializationContext* ctx,
125 IteratorStateWriter* writer) override {
126 return OkStatus();
127 }
128
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)129 Status RestoreInternal(IteratorContext* ctx,
130 IteratorStateReader* reader) override {
131 return OkStatus();
132 }
133 };
134
135 class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
136 public:
FiniteIterator(const Params & params)137 explicit FiniteIterator(const Params& params)
138 : DatasetIterator<TakeDataset>(params), i_(0) {}
139
Initialize(IteratorContext * ctx)140 Status Initialize(IteratorContext* ctx) override {
141 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
142 }
143
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)144 Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
145 bool* end_of_sequence) override {
146 mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
147 if (!input_impl_) {
148 *end_of_sequence = true;
149 return OkStatus();
150 }
151 while (dataset()->count_ < 0 || i_ < dataset()->count_) {
152 TF_RETURN_IF_ERROR(
153 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
154 if (!*end_of_sequence) {
155 ++i_;
156 return OkStatus();
157 }
158 break;
159 }
160 *end_of_sequence = true;
161 input_impl_.reset();
162 return OkStatus();
163 }
164
165 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const166 std::shared_ptr<model::Node> CreateNode(
167 IteratorContext* ctx, model::Node::Args args) const override {
168 return model::MakeKnownRatioNode(std::move(args),
169 /*ratio=*/1);
170 }
171
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)172 Status SaveInternal(SerializationContext* ctx,
173 IteratorStateWriter* writer) override {
174 mutex_lock l(mu_);
175 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
176 if (input_impl_) {
177 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
178 } else {
179 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
180 }
181 return OkStatus();
182 }
183
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)184 Status RestoreInternal(IteratorContext* ctx,
185 IteratorStateReader* reader) override {
186 mutex_lock l(mu_);
187 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
188 if (!reader->Contains(full_name(kInputImplEmpty))) {
189 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
190 } else {
191 input_impl_.reset();
192 }
193 return OkStatus();
194 }
195
196 private:
197 mutex mu_;
198 int64_t i_ TF_GUARDED_BY(mu_);
199 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
200 };
201
202 // See documentation in ../../ops/dataset_ops.cc for a high-level
203 // description of the following op.
MakeIteratorInternal(const string & prefix) const204 std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
205 const string& prefix) const {
206 if (count_ == 0) {
207 return std::make_unique<EmptyIterator>(EmptyIterator::Params{
208 this, name_utils::IteratorPrefix(kEmptyTake, prefix)});
209 } else {
210 return std::make_unique<FiniteIterator>(FiniteIterator::Params{
211 this, name_utils::IteratorPrefix(kFiniteTake, prefix)});
212 }
213 }
214
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const215 Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
216 DatasetGraphDefBuilder* b,
217 Node** output) const {
218 Node* input_graph_node = nullptr;
219 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
220 Node* count = nullptr;
221 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
222 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
223 return OkStatus();
224 }
225
TakeDatasetOp(OpKernelConstruction * ctx)226 TakeDatasetOp::TakeDatasetOp(OpKernelConstruction* ctx)
227 : UnaryDatasetOpKernel(ctx) {}
228
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)229 void TakeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
230 DatasetBase** output) {
231 // Create a new TakeDatasetOp::Dataset, and return it as the output.
232 int64_t count;
233 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kCount, &count));
234 *output = new TakeDataset(ctx, count, input);
235 }
236
237 namespace {
238 REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
239 } // namespace
240 } // namespace data
241 } // namespace tensorflow
242