1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/repeat_dataset_op.h"
16
17 #include <utility>
18
19 #include "tensorflow/core/data/name_utils.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/tensor.h"
22
23 namespace tensorflow {
24 namespace data {
25
26 // See documentation in ../../ops/dataset_ops.cc for a high-level
27 // description of the following op.
28
29 /* static */ constexpr const char* const RepeatDatasetOp::kDatasetType;
30 /* static */ constexpr const char* const RepeatDatasetOp::kInputDataset;
31 /* static */ constexpr const char* const RepeatDatasetOp::kCount;
32 /* static */ constexpr const char* const RepeatDatasetOp::kOutputTypes;
33 /* static */ constexpr const char* const RepeatDatasetOp::kOutputShapes;
34
// Names passed to name_utils::IteratorPrefix, one per iterator flavor.
constexpr char kEmptyRepeat[] = "EmptyRepeat";
constexpr char kFiniteRepeat[] = "FiniteRepeat";
constexpr char kForeverRepeat[] = "ForeverRepeat";

// Keys used by the iterators when writing and reading their saved state.
constexpr char kCurIteration[] = "i";
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kUninitialized[] = "uninitialized";

// Ratio argument given to model::MakeKnownRatioNode by every iterator below.
constexpr int64_t kKnownRatio = 1;
42
43 class RepeatDatasetOp::Dataset : public DatasetBase {
44 public:
Dataset(OpKernelContext * ctx,int64_t count,const DatasetBase * input)45 Dataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input)
46 : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
47 input_->Ref();
48 }
49
~Dataset()50 ~Dataset() override { input_->Unref(); }
51
MakeIteratorInternal(const string & prefix) const52 std::unique_ptr<IteratorBase> MakeIteratorInternal(
53 const string& prefix) const override {
54 if (count_ < 0) {
55 return std::make_unique<ForeverIterator>(ForeverIterator::Params{
56 this, name_utils::IteratorPrefix(kForeverRepeat, prefix)});
57 } else if (count_ == 0) {
58 return std::make_unique<EmptyIterator>(EmptyIterator::Params{
59 this, name_utils::IteratorPrefix(kEmptyRepeat, prefix)});
60 } else {
61 return std::make_unique<FiniteIterator>(FiniteIterator::Params{
62 this, name_utils::IteratorPrefix(kFiniteRepeat, prefix)});
63 }
64 }
65
output_dtypes() const66 const DataTypeVector& output_dtypes() const override {
67 return input_->output_dtypes();
68 }
output_shapes() const69 const std::vector<PartialTensorShape>& output_shapes() const override {
70 return input_->output_shapes();
71 }
72
DebugString() const73 string DebugString() const override {
74 return name_utils::DatasetDebugString(RepeatDatasetOp::kDatasetType);
75 }
76
CardinalityInternal() const77 int64_t CardinalityInternal() const override {
78 int64_t n = input_->Cardinality();
79 if (count_ < 0) {
80 if (n == 0) {
81 return 0;
82 }
83 return kInfiniteCardinality;
84 }
85 if (count_ == 0) {
86 return 0;
87 }
88 if (n == kInfiniteCardinality || n == kUnknownCardinality) {
89 return n;
90 }
91 return count_ * n;
92 }
93
CardinalityInternal(CardinalityOptions options) const94 int64_t CardinalityInternal(CardinalityOptions options) const override {
95 int64_t n = input_->Cardinality(options);
96 if (count_ < 0) {
97 if (n == 0) {
98 return 0;
99 }
100 return kInfiniteCardinality;
101 }
102 if (count_ == 0) {
103 return 0;
104 }
105 if (n == kInfiniteCardinality || n == kUnknownCardinality) {
106 return n;
107 }
108 return count_ * n;
109 }
110
InputDatasets(std::vector<const DatasetBase * > * inputs) const111 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
112 inputs->push_back(input_);
113 return OkStatus();
114 }
115
CheckExternalState() const116 Status CheckExternalState() const override {
117 return input_->CheckExternalState();
118 }
119
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const120 Status Get(OpKernelContext* ctx, int64 index,
121 std::vector<Tensor>* out_tensors) const override {
122 TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
123 return input_->Get(ctx, index % input_->Cardinality(), out_tensors);
124 }
125
126 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const127 Status AsGraphDefInternal(SerializationContext* ctx,
128 DatasetGraphDefBuilder* b,
129 Node** output) const override {
130 Node* input_graph_node = nullptr;
131 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
132 Node* count = nullptr;
133 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
134 TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
135 return OkStatus();
136 }
137
138 private:
139 class EmptyIterator : public DatasetIterator<Dataset> {
140 public:
EmptyIterator(const Params & params)141 explicit EmptyIterator(const Params& params)
142 : DatasetIterator<Dataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)143 Status GetNextInternal(IteratorContext* ctx,
144 std::vector<Tensor>* out_tensors,
145 bool* end_of_sequence) override {
146 *end_of_sequence = true;
147 return OkStatus();
148 }
149
150 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const151 std::shared_ptr<model::Node> CreateNode(
152 IteratorContext* ctx, model::Node::Args args) const override {
153 return model::MakeKnownRatioNode(std::move(args),
154 /*ratio=*/kKnownRatio);
155 }
156
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)157 Status SaveInternal(SerializationContext* ctx,
158 IteratorStateWriter* writer) override {
159 return OkStatus();
160 }
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)161 Status RestoreInternal(IteratorContext* ctx,
162 IteratorStateReader* reader) override {
163 return OkStatus();
164 }
165 };
166
167 class FiniteIterator : public DatasetIterator<Dataset> {
168 public:
FiniteIterator(const Params & params)169 explicit FiniteIterator(const Params& params)
170 : DatasetIterator<Dataset>(params), i_(0) {}
171
Initialize(IteratorContext * ctx)172 Status Initialize(IteratorContext* ctx) override {
173 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
174 }
175
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)176 Status GetNextInternal(IteratorContext* ctx,
177 std::vector<Tensor>* out_tensors,
178 bool* end_of_sequence) override {
179 mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
180 if (!input_impl_) {
181 *end_of_sequence = true;
182 return OkStatus();
183 }
184 while (i_ < dataset()->count_) {
185 TF_RETURN_IF_ERROR(
186 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
187 if (!*end_of_sequence) {
188 return OkStatus();
189 }
190 ++i_;
191 for (const auto& provider : ctx->split_providers()) {
192 TF_RETURN_IF_ERROR(provider->Reset());
193 }
194 TF_RETURN_IF_ERROR(
195 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
196 }
197 *end_of_sequence = true;
198 input_impl_.reset();
199 return OkStatus();
200 }
201
202 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const203 std::shared_ptr<model::Node> CreateNode(
204 IteratorContext* ctx, model::Node::Args args) const override {
205 return model::MakeKnownRatioNode(std::move(args),
206 /*ratio=*/kKnownRatio);
207 }
208
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)209 Status SaveInternal(SerializationContext* ctx,
210 IteratorStateWriter* writer) override {
211 mutex_lock l(mu_);
212 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_));
213 if (!input_impl_) {
214 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
215 } else {
216 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
217 }
218 return OkStatus();
219 }
220
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)221 Status RestoreInternal(IteratorContext* ctx,
222 IteratorStateReader* reader) override {
223 mutex_lock l(mu_);
224 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIteration), &i_));
225 if (!reader->Contains(full_name(kInputImplEmpty))) {
226 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
227 } else {
228 input_impl_.reset();
229 }
230 return OkStatus();
231 }
232
233 private:
234 mutex mu_;
235 int64_t i_ TF_GUARDED_BY(mu_);
236 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
237 };
238
239 class ForeverIterator : public DatasetIterator<Dataset> {
240 public:
ForeverIterator(const Params & params)241 explicit ForeverIterator(const Params& params)
242 : DatasetIterator<Dataset>(params),
243 input_impl_(nullptr),
244 first_call_(true) {}
245
Initialize(IteratorContext * ctx)246 Status Initialize(IteratorContext* ctx) override {
247 mutex_lock l(mu_);
248 return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
249 }
250
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)251 Status GetNextInternal(IteratorContext* ctx,
252 std::vector<Tensor>* out_tensors,
253 bool* end_of_sequence) override {
254 mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
255 do {
256 if (!input_impl_) {
257 TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
258 ctx, this, prefix(), &input_impl_));
259 }
260 TF_RETURN_IF_ERROR(
261 input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
262 DCHECK(!*end_of_sequence || out_tensors->empty());
263 if (first_call_ && *end_of_sequence && ctx->split_providers().empty()) {
264 // If the first call to GetNext() fails because the end of sequence
265 // has been reached, we terminate the iteration immediately.
266 // Otherwise, this iterator would loop infinitely and never produce a
267 // value.
268 input_impl_.reset();
269 return OkStatus();
270 }
271 first_call_ = false;
272 if (!*end_of_sequence) {
273 return OkStatus();
274 }
275 for (const auto& provider : ctx->split_providers()) {
276 TF_RETURN_IF_ERROR(provider->Reset());
277 }
278 input_impl_.reset();
279 first_call_ = true;
280 } while (true);
281 }
282
283 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const284 std::shared_ptr<model::Node> CreateNode(
285 IteratorContext* ctx, model::Node::Args args) const override {
286 return model::MakeKnownRatioNode(std::move(args),
287 /*ratio=*/kKnownRatio);
288 }
289
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)290 Status SaveInternal(SerializationContext* ctx,
291 IteratorStateWriter* writer) override {
292 mutex_lock l(mu_);
293 if (!first_call_)
294 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
295 else
296 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), ""));
297 return OkStatus();
298 }
299
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)300 Status RestoreInternal(IteratorContext* ctx,
301 IteratorStateReader* reader) override {
302 mutex_lock l(mu_);
303 if (reader->Contains(full_name(kUninitialized))) {
304 input_impl_.reset();
305 first_call_ = true;
306 } else {
307 TF_RETURN_IF_ERROR(
308 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
309 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
310 first_call_ = false;
311 }
312 return OkStatus();
313 }
314
315 private:
316 mutex mu_;
317 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
318 bool first_call_ TF_GUARDED_BY(mu_);
319 };
320
321 const int64_t count_;
322 const DatasetBase* const input_;
323 };
324
RepeatDatasetOp(OpKernelConstruction * ctx)325 RepeatDatasetOp::RepeatDatasetOp(OpKernelConstruction* ctx)
326 : UnaryDatasetOpKernel(ctx) {}
327
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)328 void RepeatDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
329 DatasetBase** output) {
330 // Create a new RepeatDatasetOp::Dataset, insert it in the step-local
331 // container, and return it as the output.
332 int64_t count;
333 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kCount, &count));
334 *output = new Dataset(ctx, count, input);
335 }
336
337 namespace {
338 REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
339 RepeatDatasetOp);
340 } // namespace
341 } // namespace data
342 } // namespace tensorflow
343