xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/concatenate_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/concatenate_dataset_op.h"
16 
17 #include <string>
18 #include <utility>
19 
20 #include "tensorflow/core/data/name_utils.h"
21 #include "tensorflow/core/data/split_utils.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/tensor.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 // See documentation in ../../ops/dataset_ops.cc for a high-level
29 // description of the following op.
30 
31 /* static */ constexpr const char* const ConcatenateDatasetOp::kDatasetType;
32 /* static */ constexpr const char* const ConcatenateDatasetOp::kInputDataset;
33 /* static */ constexpr const char* const ConcatenateDatasetOp::kAnotherDataset;
34 /* static */ constexpr const char* const ConcatenateDatasetOp::kOutputTypes;
35 /* static */ constexpr const char* const ConcatenateDatasetOp::kOutputShapes;
36 
37 constexpr char kIndex[] = "i";
38 constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";
39 
40 class ConcatenateDatasetOp::Dataset : public DatasetBase {
41  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,const DatasetBase * to_concatenate)42   explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
43                    const DatasetBase* to_concatenate)
44       : DatasetBase(DatasetContext(ctx)),
45         input_(input),
46         to_concatenate_(to_concatenate),
47         input_cardinality_(input->Cardinality()),
48         to_concatenate_cardinality_(to_concatenate_->Cardinality()) {
49     input_->Ref();
50     to_concatenate_->Ref();
51 
52     auto os_input = input->output_shapes();
53     auto os_concatenate = to_concatenate->output_shapes();
54     for (int i = 0; i < os_input.size(); i++) {
55       output_shapes_.push_back(
56           MostSpecificCompatibleShape(os_input[i], os_concatenate[i]));
57     }
58   }
~Dataset()59   ~Dataset() override {
60     input_->Unref();
61     to_concatenate_->Unref();
62   }
63 
MakeIteratorInternal(const string & prefix) const64   std::unique_ptr<IteratorBase> MakeIteratorInternal(
65       const string& prefix) const override {
66     return std::make_unique<Iterator>(Iterator::Params{
67         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
68   }
69 
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * split_providers) const70   Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
71                                 split_providers) const override {
72     TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this));
73     return OkStatus();
74   }
75 
output_dtypes() const76   const DataTypeVector& output_dtypes() const override {
77     return input_->output_dtypes();
78   }
79 
output_shapes() const80   const std::vector<PartialTensorShape>& output_shapes() const override {
81     return output_shapes_;
82   }
83 
DebugString() const84   string DebugString() const override {
85     return name_utils::DatasetDebugString(kDatasetType);
86   }
87 
CardinalityInternal() const88   int64_t CardinalityInternal() const override {
89     if (input_cardinality_ == kInfiniteCardinality ||
90         to_concatenate_cardinality_ == kInfiniteCardinality) {
91       return kInfiniteCardinality;
92     }
93     if (input_cardinality_ == kUnknownCardinality ||
94         to_concatenate_cardinality_ == kUnknownCardinality) {
95       return kUnknownCardinality;
96     }
97     return input_cardinality_ + to_concatenate_cardinality_;
98   }
99 
CardinalityInternal(CardinalityOptions options) const100   int64_t CardinalityInternal(CardinalityOptions options) const override {
101     int64_t input_cardinality = input_->Cardinality(options);
102     int64_t to_concatenate_cardinality = to_concatenate_->Cardinality(options);
103 
104     if (input_cardinality == kInfiniteCardinality ||
105         to_concatenate_cardinality == kInfiniteCardinality) {
106       return kInfiniteCardinality;
107     }
108     if (input_cardinality == kUnknownCardinality ||
109         to_concatenate_cardinality == kUnknownCardinality) {
110       return kUnknownCardinality;
111     }
112     return input_cardinality + to_concatenate_cardinality;
113   }
114 
InputDatasets(std::vector<const DatasetBase * > * inputs) const115   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
116     inputs->push_back(input_);
117     inputs->push_back(to_concatenate_);
118     return OkStatus();
119   }
120 
CheckExternalState() const121   Status CheckExternalState() const override {
122     TF_RETURN_IF_ERROR(input_->CheckExternalState());
123     return to_concatenate_->CheckExternalState();
124   }
125 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const126   Status Get(OpKernelContext* ctx, int64 index,
127              std::vector<Tensor>* out_tensors) const override {
128     TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
129     if (index < input_cardinality_) {
130       TF_RETURN_IF_ERROR(input_->Get(ctx, index, out_tensors));
131     } else {
132       TF_RETURN_IF_ERROR(
133           to_concatenate_->Get(ctx, index - input_cardinality_, out_tensors));
134     }
135     return OkStatus();
136   }
137 
138  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const139   Status AsGraphDefInternal(SerializationContext* ctx,
140                             DatasetGraphDefBuilder* b,
141                             Node** output) const override {
142     Node* input_graph = nullptr;
143     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
144     Node* to_concatenate_graph = nullptr;
145     TF_RETURN_IF_ERROR(
146         b->AddInputDataset(ctx, to_concatenate_, &to_concatenate_graph));
147     TF_RETURN_IF_ERROR(
148         b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
149     return OkStatus();
150   }
151 
152  private:
153   class Iterator : public DatasetIterator<Dataset> {
154    public:
Iterator(const Params & params)155     explicit Iterator(const Params& params)
156         : DatasetIterator<Dataset>(params), i_(0) {}
157 
Initialize(IteratorContext * ctx)158     Status Initialize(IteratorContext* ctx) override {
159       TF_ASSIGN_OR_RETURN(input_contexts_,
160                           CreateInputIteratorContexts(ctx, dataset()));
161       return dataset()->input_->MakeIterator(&input_contexts_[0], this,
162                                              strings::StrCat(prefix(), "[0]"),
163                                              &input_impl_);
164     }
165 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)166     Status GetNextInternal(IteratorContext* ctx,
167                            std::vector<Tensor>* out_tensors,
168                            bool* end_of_sequence) override {
169       mutex_lock l(mu_);
170       if (!input_impl_) {
171         *end_of_sequence = true;
172         return OkStatus();
173       }
174       while (i_ < 2) {
175         TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_],
176                                                 out_tensors, end_of_sequence));
177         if (!*end_of_sequence) {
178           return OkStatus();
179         }
180         if (++i_ < 2) {
181           TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
182               &input_contexts_[i_], this, strings::StrCat(prefix(), "[1]"),
183               &input_impl_));
184         }
185       }
186       *end_of_sequence = true;
187       input_impl_.reset();
188       return OkStatus();
189     }
190 
191    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const192     std::shared_ptr<model::Node> CreateNode(
193         IteratorContext* ctx, model::Node::Args args) const override {
194       return model::MakeKnownRatioNode(std::move(args),
195                                        /*ratio=*/1);
196     }
197 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)198     Status SaveInternal(SerializationContext* ctx,
199                         IteratorStateWriter* writer) override {
200       mutex_lock l(mu_);
201       TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_));
202       if (input_impl_) {
203         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
204       } else {
205         TF_RETURN_IF_ERROR(
206             writer->WriteScalar(full_name(kInputImplUninitialized), ""));
207       }
208       return OkStatus();
209     }
210 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)211     Status RestoreInternal(IteratorContext* ctx,
212                            IteratorStateReader* reader) override {
213       mutex_lock l(mu_);
214       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &i_));
215       if (reader->Contains(full_name(kInputImplUninitialized))) {
216         input_impl_.reset();
217         return OkStatus();
218       }
219       if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
220         return errors::InvalidArgument("i_ must be in range [0, 2].");
221       if (i_ == 1) {
222         TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
223             ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
224       } else if (i_ == 2) {
225         input_impl_.reset();
226       }
227       if (input_impl_) {
228         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
229       }
230       return OkStatus();
231     }
232 
233    private:
234     mutex mu_;
235     int64_t i_ TF_GUARDED_BY(mu_);
236     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
237     std::vector<IteratorContext> input_contexts_;
238   };
239 
MostSpecificCompatibleShape(const PartialTensorShape & ts1,const PartialTensorShape & ts2)240   static PartialTensorShape MostSpecificCompatibleShape(
241       const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
242     if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
243       return PartialTensorShape();
244     PartialTensorShape output_tensorshape({});
245     auto dims1 = ts1.dim_sizes();
246     auto dims2 = ts2.dim_sizes();
247     for (int d = 0; d < ts1.dims(); d++) {
248       if (dims1[d] == dims2[d])
249         output_tensorshape.AddDim(dims1[d]);
250       else
251         output_tensorshape.AddDim(-1);
252     }
253     return output_tensorshape;
254   }
255 
256   const DatasetBase* input_;
257   const DatasetBase* to_concatenate_;
258   const int64_t input_cardinality_;
259   const int64_t to_concatenate_cardinality_;
260   std::vector<PartialTensorShape> output_shapes_;
261 };
262 
ConcatenateDatasetOp(OpKernelConstruction * ctx)263 ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
264     : BinaryDatasetOpKernel(ctx) {}
265 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase * to_concatenate,DatasetBase ** output)266 void ConcatenateDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
267                                        DatasetBase* to_concatenate,
268                                        DatasetBase** output) {
269   OP_REQUIRES(ctx, input->output_dtypes() == to_concatenate->output_dtypes(),
270               errors::InvalidArgument(
271                   "input dataset and dataset to concatenate"
272                   " have different output_types %s and %s",
273                   (DataTypeVectorString(input->output_dtypes()),
274                    DataTypeVectorString(to_concatenate->output_dtypes()))));
275   *output = new Dataset(ctx, input, to_concatenate);
276 }
277 
278 namespace {
279 REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
280                         ConcatenateDatasetOp);
281 }  // namespace
282 }  // namespace data
283 }  // namespace tensorflow
284