xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/take_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/take_dataset_op.h"
16 
17 #include "tensorflow/core/data/name_utils.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21 
22 namespace tensorflow {
23 namespace data {
24 
25 /* static */ constexpr const char* const TakeDatasetOp::kDatasetType;
26 /* static */ constexpr const char* const TakeDatasetOp::kInputDataset;
27 /* static */ constexpr const char* const TakeDatasetOp::kCount;
28 /* static */ constexpr const char* const TakeDatasetOp::kOutputTypes;
29 /* static */ constexpr const char* const TakeDatasetOp::kOutputShapes;
30 
31 constexpr char kCurIndex[] = "i";
32 constexpr char kInputImplEmpty[] = "input_impl_empty";
33 constexpr char kEmptyTake[] = "EmptyTake";
34 constexpr char kFiniteTake[] = "FiniteTake";
35 
TakeDataset(OpKernelContext * ctx,int64_t count,const DatasetBase * input)36 TakeDataset::TakeDataset(OpKernelContext* ctx, int64_t count,
37                          const DatasetBase* input)
38     : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
39   input_->Ref();
40 }
41 
TakeDataset(DatasetContext::Params params,int64_t count,const DatasetBase * input)42 TakeDataset::TakeDataset(DatasetContext::Params params, int64_t count,
43                          const DatasetBase* input)
44     : DatasetBase(DatasetContext(std::move(params))),
45       count_(count),
46       input_(input) {
47   input_->Ref();
48 }
49 
~TakeDataset()50 TakeDataset::~TakeDataset() { input_->Unref(); }
51 
output_dtypes() const52 const DataTypeVector& TakeDataset::output_dtypes() const {
53   return input_->output_dtypes();
54 }
55 
output_shapes() const56 const std::vector<PartialTensorShape>& TakeDataset::output_shapes() const {
57   return input_->output_shapes();
58 }
59 
DebugString() const60 string TakeDataset::DebugString() const {
61   return name_utils::DatasetDebugString(TakeDatasetOp::kDatasetType);
62 }
63 
CardinalityInternal() const64 int64_t TakeDataset::CardinalityInternal() const {
65   int64_t n = input_->Cardinality();
66   if (n == kUnknownCardinality) {
67     return kUnknownCardinality;
68   }
69   if (n == kInfiniteCardinality) {
70     return count_;
71   } else if (count_ == kInfiniteCardinality) {
72     return n;
73   }
74   return std::min(n, count_);
75 }
76 
CardinalityInternal(CardinalityOptions options) const77 int64_t TakeDataset::CardinalityInternal(CardinalityOptions options) const {
78   int64_t n = input_->Cardinality(options);
79   if (n == kUnknownCardinality) {
80     return kUnknownCardinality;
81   }
82   if (n == kInfiniteCardinality) {
83     return count_;
84   } else if (count_ == kInfiniteCardinality) {
85     return n;
86   }
87 
88   return std::min(n, count_);
89 }
90 
InputDatasets(std::vector<const DatasetBase * > * inputs) const91 Status TakeDataset::InputDatasets(
92     std::vector<const DatasetBase*>* inputs) const {
93   inputs->push_back(input_);
94   return OkStatus();
95 }
96 
CheckExternalState() const97 Status TakeDataset::CheckExternalState() const {
98   return input_->CheckExternalState();
99 }
100 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const101 Status TakeDataset::Get(OpKernelContext* ctx, int64 index,
102                         std::vector<Tensor>* out_tensors) const {
103   TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
104   return input_->Get(ctx, index, out_tensors);
105 }
106 
107 class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
108  public:
EmptyIterator(const Params & params)109   explicit EmptyIterator(const Params& params)
110       : DatasetIterator<TakeDataset>(params) {}
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)111   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
112                          bool* end_of_sequence) override {
113     *end_of_sequence = true;
114     return OkStatus();
115   }
116 
117  protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const118   std::shared_ptr<model::Node> CreateNode(
119       IteratorContext* ctx, model::Node::Args args) const override {
120     return model::MakeKnownRatioNode(std::move(args),
121                                      /*ratio=*/1);
122   }
123 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)124   Status SaveInternal(SerializationContext* ctx,
125                       IteratorStateWriter* writer) override {
126     return OkStatus();
127   }
128 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)129   Status RestoreInternal(IteratorContext* ctx,
130                          IteratorStateReader* reader) override {
131     return OkStatus();
132   }
133 };
134 
135 class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
136  public:
FiniteIterator(const Params & params)137   explicit FiniteIterator(const Params& params)
138       : DatasetIterator<TakeDataset>(params), i_(0) {}
139 
Initialize(IteratorContext * ctx)140   Status Initialize(IteratorContext* ctx) override {
141     return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
142   }
143 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)144   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
145                          bool* end_of_sequence) override {
146     mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
147     if (!input_impl_) {
148       *end_of_sequence = true;
149       return OkStatus();
150     }
151     while (dataset()->count_ < 0 || i_ < dataset()->count_) {
152       TF_RETURN_IF_ERROR(
153           input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
154       if (!*end_of_sequence) {
155         ++i_;
156         return OkStatus();
157       }
158       break;
159     }
160     *end_of_sequence = true;
161     input_impl_.reset();
162     return OkStatus();
163   }
164 
165  protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const166   std::shared_ptr<model::Node> CreateNode(
167       IteratorContext* ctx, model::Node::Args args) const override {
168     return model::MakeKnownRatioNode(std::move(args),
169                                      /*ratio=*/1);
170   }
171 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)172   Status SaveInternal(SerializationContext* ctx,
173                       IteratorStateWriter* writer) override {
174     mutex_lock l(mu_);
175     TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
176     if (input_impl_) {
177       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
178     } else {
179       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
180     }
181     return OkStatus();
182   }
183 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)184   Status RestoreInternal(IteratorContext* ctx,
185                          IteratorStateReader* reader) override {
186     mutex_lock l(mu_);
187     TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &i_));
188     if (!reader->Contains(full_name(kInputImplEmpty))) {
189       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
190     } else {
191       input_impl_.reset();
192     }
193     return OkStatus();
194   }
195 
196  private:
197   mutex mu_;
198   int64_t i_ TF_GUARDED_BY(mu_);
199   std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
200 };
201 
// The TakeDataset op implemented in this file is documented (at a high level)
// in ../../ops/dataset_ops.cc.
MakeIteratorInternal(const string & prefix) const204 std::unique_ptr<IteratorBase> TakeDataset::MakeIteratorInternal(
205     const string& prefix) const {
206   if (count_ == 0) {
207     return std::make_unique<EmptyIterator>(EmptyIterator::Params{
208         this, name_utils::IteratorPrefix(kEmptyTake, prefix)});
209   } else {
210     return std::make_unique<FiniteIterator>(FiniteIterator::Params{
211         this, name_utils::IteratorPrefix(kFiniteTake, prefix)});
212   }
213 }
214 
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const215 Status TakeDataset::AsGraphDefInternal(SerializationContext* ctx,
216                                        DatasetGraphDefBuilder* b,
217                                        Node** output) const {
218   Node* input_graph_node = nullptr;
219   TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
220   Node* count = nullptr;
221   TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
222   TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node, count}, output));
223   return OkStatus();
224 }
225 
TakeDatasetOp(OpKernelConstruction * ctx)226 TakeDatasetOp::TakeDatasetOp(OpKernelConstruction* ctx)
227     : UnaryDatasetOpKernel(ctx) {}
228 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)229 void TakeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
230                                 DatasetBase** output) {
231   // Create a new TakeDatasetOp::Dataset, and return it as the output.
232   int64_t count;
233   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kCount, &count));
234   *output = new TakeDataset(ctx, count, input);
235 }
236 
237 namespace {
238 REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
239 }  // namespace
240 }  // namespace data
241 }  // namespace tensorflow
242