/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/concatenate_dataset_op.h"

#include <string>
#include <utility>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const ConcatenateDatasetOp::kDatasetType;
/* static */ constexpr const char* const ConcatenateDatasetOp::kInputDataset;
/* static */ constexpr const char* const ConcatenateDatasetOp::kAnotherDataset;
/* static */ constexpr const char* const ConcatenateDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ConcatenateDatasetOp::kOutputShapes;

constexpr char kIndex[] = "i";
constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";

class ConcatenateDatasetOp::Dataset : public DatasetBase {
 public:
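  // A dataset that emits the elements of `input` followed by the elements of
  // `to_concatenate`. Output shapes are relaxed to the most specific shape
  // compatible with both inputs.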
  explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
                   const DatasetBase* to_concatenate)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        to_concatenate_(to_concatenate),
        input_cardinality_(input->Cardinality()),
        to_concatenate_cardinality_(to_concatenate_->Cardinality()) {
    input_->Ref();
    to_concatenate_->Ref();

    auto os_input = input->output_shapes();
    auto os_concatenate = to_concatenate->output_shapes();
    for (int i = 0; i < os_input.size(); i++) {
      output_shapes_.push_back(
          MostSpecificCompatibleShape(os_input[i], os_concatenate[i]));
    }
  }

  ~Dataset() override {
    input_->Unref();
    to_concatenate_->Unref();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
                                split_providers) const override {
    TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this));
    return OkStatus();
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

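  // Combined cardinality: infinite if either input is infinite, unknown if
  // either is unknown (and neither is infinite), otherwise the sum of the two.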
  int64_t CardinalityInternal() const override {
    if (input_cardinality_ == kInfiniteCardinality ||
        to_concatenate_cardinality_ == kInfiniteCardinality) {
      return kInfiniteCardinality;
    }
    if (input_cardinality_ == kUnknownCardinality ||
        to_concatenate_cardinality_ == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    return input_cardinality_ + to_concatenate_cardinality_;
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    int64_t input_cardinality = input_->Cardinality(options);
    int64_t to_concatenate_cardinality = to_concatenate_->Cardinality(options);

    if (input_cardinality == kInfiniteCardinality ||
        to_concatenate_cardinality == kInfiniteCardinality) {
      return kInfiniteCardinality;
    }
    if (input_cardinality == kUnknownCardinality ||
        to_concatenate_cardinality == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    return input_cardinality + to_concatenate_cardinality;
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    inputs->push_back(to_concatenate_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(input_->CheckExternalState());
    return to_concatenate_->CheckExternalState();
  }

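  // Random access. Indices below `input_cardinality_` are served by `input_`;
  // the remainder are forwarded to `to_concatenate_` with the offset removed.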
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
    if (index < input_cardinality_) {
      TF_RETURN_IF_ERROR(input_->Get(ctx, index, out_tensors));
    } else {
      TF_RETURN_IF_ERROR(
          to_concatenate_->Get(ctx, index - input_cardinality_, out_tensors));
    }
    return OkStatus();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
    Node* to_concatenate_graph = nullptr;
    TF_RETURN_IF_ERROR(
        b->AddInputDataset(ctx, to_concatenate_, &to_concatenate_graph));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph, to_concatenate_graph}, output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), i_(0) {}

    Status Initialize(IteratorContext* ctx) override {
      TF_ASSIGN_OR_RETURN(input_contexts_,
                          CreateInputIteratorContexts(ctx, dataset()));
      return dataset()->input_->MakeIterator(&input_contexts_[0], this,
                                             strings::StrCat(prefix(), "[0]"),
                                             &input_impl_);
    }

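    // `i_` tracks which input is currently being drawn from: 0 for `input_`,
    // 1 for `to_concatenate_`, and 2 once both have been exhausted.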
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        *end_of_sequence = true;
        return OkStatus();
      }
      while (i_ < 2) {
        TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_],
                                                out_tensors, end_of_sequence));
        if (!*end_of_sequence) {
          return OkStatus();
        }
        if (++i_ < 2) {
          TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
              &input_contexts_[i_], this, strings::StrCat(prefix(), "[1]"),
              &input_impl_));
        }
      }
      *end_of_sequence = true;
      input_impl_.reset();
      return OkStatus();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

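    // Checkpointing: the current input index `i_` is always saved; the nested
    // iterator state is saved only while an input iterator is still live.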
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_));
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kInputImplUninitialized), ""));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &i_));
      if (reader->Contains(full_name(kInputImplUninitialized))) {
        input_impl_.reset();
        return OkStatus();
      }
      if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
        return errors::InvalidArgument("i_ must be in range [0, 2].");
      if (i_ == 1) {
        TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
            ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_));
      } else if (i_ == 2) {
        input_impl_.reset();
      }
      if (input_impl_) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return OkStatus();
    }

   private:
    mutex mu_;
    int64_t i_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
    std::vector<IteratorContext> input_contexts_;
  };

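  // Returns the most specific shape compatible with both `ts1` and `ts2`:
  // matching ranks keep the dimensions that agree and mark the rest as
  // unknown; differing or unknown ranks yield a fully unknown shape.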
  static PartialTensorShape MostSpecificCompatibleShape(
      const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
    if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
      return PartialTensorShape();
    PartialTensorShape output_tensorshape({});
    auto dims1 = ts1.dim_sizes();
    auto dims2 = ts2.dim_sizes();
    for (int d = 0; d < ts1.dims(); d++) {
      if (dims1[d] == dims2[d])
        output_tensorshape.AddDim(dims1[d]);
      else
        output_tensorshape.AddDim(-1);
    }
    return output_tensorshape;
  }

  const DatasetBase* input_;
  const DatasetBase* to_concatenate_;
  const int64_t input_cardinality_;
  const int64_t to_concatenate_cardinality_;
  std::vector<PartialTensorShape> output_shapes_;
};

ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
    : BinaryDatasetOpKernel(ctx) {}

void ConcatenateDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase* to_concatenate,
                                       DatasetBase** output) {
  OP_REQUIRES(ctx, input->output_dtypes() == to_concatenate->output_dtypes(),
              errors::InvalidArgument(
                  "input dataset and dataset to concatenate"
                  " have different output_types ",
                  DataTypeVectorString(input->output_dtypes()), " and ",
                  DataTypeVectorString(to_concatenate->output_dtypes())));
  *output = new Dataset(ctx, input, to_concatenate);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
                        ConcatenateDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow