xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/repeat_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/kernels/data/repeat_dataset_op.h"

#include <limits>
#include <utility>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
22 
23 namespace tensorflow {
24 namespace data {
25 
26 // See documentation in ../../ops/dataset_ops.cc for a high-level
27 // description of the following op.
28 
29 /* static */ constexpr const char* const RepeatDatasetOp::kDatasetType;
30 /* static */ constexpr const char* const RepeatDatasetOp::kInputDataset;
31 /* static */ constexpr const char* const RepeatDatasetOp::kCount;
32 /* static */ constexpr const char* const RepeatDatasetOp::kOutputTypes;
33 /* static */ constexpr const char* const RepeatDatasetOp::kOutputShapes;
34 
// Ratio handed to model::MakeKnownRatioNode by every iterator in this file.
constexpr int64_t kKnownRatio = 1;

// Names passed to name_utils::IteratorPrefix, one per repeat mode
// (count < 0, count == 0, count > 0 respectively).
constexpr char kForeverRepeat[] = "ForeverRepeat";
constexpr char kEmptyRepeat[] = "EmptyRepeat";
constexpr char kFiniteRepeat[] = "FiniteRepeat";

// Keys used with full_name() when saving/restoring iterator checkpoints.
constexpr char kCurIteration[] = "i";
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kUninitialized[] = "uninitialized";
42 
43 class RepeatDatasetOp::Dataset : public DatasetBase {
44  public:
Dataset(OpKernelContext * ctx,int64_t count,const DatasetBase * input)45   Dataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input)
46       : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
47     input_->Ref();
48   }
49 
~Dataset()50   ~Dataset() override { input_->Unref(); }
51 
MakeIteratorInternal(const string & prefix) const52   std::unique_ptr<IteratorBase> MakeIteratorInternal(
53       const string& prefix) const override {
54     if (count_ < 0) {
55       return std::make_unique<ForeverIterator>(ForeverIterator::Params{
56           this, name_utils::IteratorPrefix(kForeverRepeat, prefix)});
57     } else if (count_ == 0) {
58       return std::make_unique<EmptyIterator>(EmptyIterator::Params{
59           this, name_utils::IteratorPrefix(kEmptyRepeat, prefix)});
60     } else {
61       return std::make_unique<FiniteIterator>(FiniteIterator::Params{
62           this, name_utils::IteratorPrefix(kFiniteRepeat, prefix)});
63     }
64   }
65 
output_dtypes() const66   const DataTypeVector& output_dtypes() const override {
67     return input_->output_dtypes();
68   }
output_shapes() const69   const std::vector<PartialTensorShape>& output_shapes() const override {
70     return input_->output_shapes();
71   }
72 
DebugString() const73   string DebugString() const override {
74     return name_utils::DatasetDebugString(RepeatDatasetOp::kDatasetType);
75   }
76 
CardinalityInternal() const77   int64_t CardinalityInternal() const override {
78     int64_t n = input_->Cardinality();
79     if (count_ < 0) {
80       if (n == 0) {
81         return 0;
82       }
83       return kInfiniteCardinality;
84     }
85     if (count_ == 0) {
86       return 0;
87     }
88     if (n == kInfiniteCardinality || n == kUnknownCardinality) {
89       return n;
90     }
91     return count_ * n;
92   }
93 
CardinalityInternal(CardinalityOptions options) const94   int64_t CardinalityInternal(CardinalityOptions options) const override {
95     int64_t n = input_->Cardinality(options);
96     if (count_ < 0) {
97       if (n == 0) {
98         return 0;
99       }
100       return kInfiniteCardinality;
101     }
102     if (count_ == 0) {
103       return 0;
104     }
105     if (n == kInfiniteCardinality || n == kUnknownCardinality) {
106       return n;
107     }
108     return count_ * n;
109   }
110 
InputDatasets(std::vector<const DatasetBase * > * inputs) const111   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
112     inputs->push_back(input_);
113     return OkStatus();
114   }
115 
CheckExternalState() const116   Status CheckExternalState() const override {
117     return input_->CheckExternalState();
118   }
119 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const120   Status Get(OpKernelContext* ctx, int64 index,
121              std::vector<Tensor>* out_tensors) const override {
122     TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
123     return input_->Get(ctx, index % input_->Cardinality(), out_tensors);
124   }
125 
126  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const127   Status AsGraphDefInternal(SerializationContext* ctx,
128                             DatasetGraphDefBuilder* b,
129                             Node** output) const override {
130     Node* input_graph_node = nullptr;
131     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
132     Node* count = nullptr;
133     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
134     TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
135     return OkStatus();
136   }
137 
138  private:
139   class EmptyIterator : public DatasetIterator<Dataset> {
140    public:
EmptyIterator(const Params & params)141     explicit EmptyIterator(const Params& params)
142         : DatasetIterator<Dataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)143     Status GetNextInternal(IteratorContext* ctx,
144                            std::vector<Tensor>* out_tensors,
145                            bool* end_of_sequence) override {
146       *end_of_sequence = true;
147       return OkStatus();
148     }
149 
150    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const151     std::shared_ptr<model::Node> CreateNode(
152         IteratorContext* ctx, model::Node::Args args) const override {
153       return model::MakeKnownRatioNode(std::move(args),
154                                        /*ratio=*/kKnownRatio);
155     }
156 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)157     Status SaveInternal(SerializationContext* ctx,
158                         IteratorStateWriter* writer) override {
159       return OkStatus();
160     }
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)161     Status RestoreInternal(IteratorContext* ctx,
162                            IteratorStateReader* reader) override {
163       return OkStatus();
164     }
165   };
166 
167   class FiniteIterator : public DatasetIterator<Dataset> {
168    public:
FiniteIterator(const Params & params)169     explicit FiniteIterator(const Params& params)
170         : DatasetIterator<Dataset>(params), i_(0) {}
171 
Initialize(IteratorContext * ctx)172     Status Initialize(IteratorContext* ctx) override {
173       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
174     }
175 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)176     Status GetNextInternal(IteratorContext* ctx,
177                            std::vector<Tensor>* out_tensors,
178                            bool* end_of_sequence) override {
179       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
180       if (!input_impl_) {
181         *end_of_sequence = true;
182         return OkStatus();
183       }
184       while (i_ < dataset()->count_) {
185         TF_RETURN_IF_ERROR(
186             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
187         if (!*end_of_sequence) {
188           return OkStatus();
189         }
190         ++i_;
191         for (const auto& provider : ctx->split_providers()) {
192           TF_RETURN_IF_ERROR(provider->Reset());
193         }
194         TF_RETURN_IF_ERROR(
195             dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
196       }
197       *end_of_sequence = true;
198       input_impl_.reset();
199       return OkStatus();
200     }
201 
202    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const203     std::shared_ptr<model::Node> CreateNode(
204         IteratorContext* ctx, model::Node::Args args) const override {
205       return model::MakeKnownRatioNode(std::move(args),
206                                        /*ratio=*/kKnownRatio);
207     }
208 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)209     Status SaveInternal(SerializationContext* ctx,
210                         IteratorStateWriter* writer) override {
211       mutex_lock l(mu_);
212       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_));
213       if (!input_impl_) {
214         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
215       } else {
216         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
217       }
218       return OkStatus();
219     }
220 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)221     Status RestoreInternal(IteratorContext* ctx,
222                            IteratorStateReader* reader) override {
223       mutex_lock l(mu_);
224       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIteration), &i_));
225       if (!reader->Contains(full_name(kInputImplEmpty))) {
226         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
227       } else {
228         input_impl_.reset();
229       }
230       return OkStatus();
231     }
232 
233    private:
234     mutex mu_;
235     int64_t i_ TF_GUARDED_BY(mu_);
236     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
237   };
238 
239   class ForeverIterator : public DatasetIterator<Dataset> {
240    public:
ForeverIterator(const Params & params)241     explicit ForeverIterator(const Params& params)
242         : DatasetIterator<Dataset>(params),
243           input_impl_(nullptr),
244           first_call_(true) {}
245 
Initialize(IteratorContext * ctx)246     Status Initialize(IteratorContext* ctx) override {
247       mutex_lock l(mu_);
248       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
249     }
250 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)251     Status GetNextInternal(IteratorContext* ctx,
252                            std::vector<Tensor>* out_tensors,
253                            bool* end_of_sequence) override {
254       mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
255       do {
256         if (!input_impl_) {
257           TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
258               ctx, this, prefix(), &input_impl_));
259         }
260         TF_RETURN_IF_ERROR(
261             input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
262         DCHECK(!*end_of_sequence || out_tensors->empty());
263         if (first_call_ && *end_of_sequence && ctx->split_providers().empty()) {
264           // If the first call to GetNext() fails because the end of sequence
265           // has been reached, we terminate the iteration immediately.
266           // Otherwise, this iterator would loop infinitely and never produce a
267           // value.
268           input_impl_.reset();
269           return OkStatus();
270         }
271         first_call_ = false;
272         if (!*end_of_sequence) {
273           return OkStatus();
274         }
275         for (const auto& provider : ctx->split_providers()) {
276           TF_RETURN_IF_ERROR(provider->Reset());
277         }
278         input_impl_.reset();
279         first_call_ = true;
280       } while (true);
281     }
282 
283    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const284     std::shared_ptr<model::Node> CreateNode(
285         IteratorContext* ctx, model::Node::Args args) const override {
286       return model::MakeKnownRatioNode(std::move(args),
287                                        /*ratio=*/kKnownRatio);
288     }
289 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)290     Status SaveInternal(SerializationContext* ctx,
291                         IteratorStateWriter* writer) override {
292       mutex_lock l(mu_);
293       if (!first_call_)
294         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
295       else
296         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), ""));
297       return OkStatus();
298     }
299 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)300     Status RestoreInternal(IteratorContext* ctx,
301                            IteratorStateReader* reader) override {
302       mutex_lock l(mu_);
303       if (reader->Contains(full_name(kUninitialized))) {
304         input_impl_.reset();
305         first_call_ = true;
306       } else {
307         TF_RETURN_IF_ERROR(
308             dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
309         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
310         first_call_ = false;
311       }
312       return OkStatus();
313     }
314 
315    private:
316     mutex mu_;
317     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
318     bool first_call_ TF_GUARDED_BY(mu_);
319   };
320 
321   const int64_t count_;
322   const DatasetBase* const input_;
323 };
324 
RepeatDatasetOp(OpKernelConstruction * ctx)325 RepeatDatasetOp::RepeatDatasetOp(OpKernelConstruction* ctx)
326     : UnaryDatasetOpKernel(ctx) {}
327 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)328 void RepeatDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
329                                   DatasetBase** output) {
330   // Create a new RepeatDatasetOp::Dataset, insert it in the step-local
331   // container, and return it as the output.
332   int64_t count;
333   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kCount, &count));
334   *output = new Dataset(ctx, count, input);
335 }
336 
337 namespace {
338 REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
339                         RepeatDatasetOp);
340 }  // namespace
341 }  // namespace data
342 }  // namespace tensorflow
343