/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"

#include <string>
#include <utility>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kSelectorInputDataset;
/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kDataInputDatasets;
/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kStopOnEmptyDataset;
/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    DirectedInterleaveDatasetOp::kNumInputDatasets;

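// Parameter name reported to the tf.data autotuning model; this iterator
// behaves like an interleave with a fixed cycle length of 1 (see
// Iterator::CreateNode below).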
constexpr char kCycleLength[] = "cycle_length";

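// Dataset whose elements are drawn from its data input datasets, with the
// selector input dataset deciding which data input provides each element.
// This op underlies tf.data's `choose_from_datasets` and
// `sample_from_datasets` helpers.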
class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
          std::vector<DatasetBase*> data_inputs, bool stop_on_empty_dataset)
      : DatasetBase(DatasetContext(ctx)),
        selector_input_(selector_input),
        data_inputs_(std::move(data_inputs)),
        stop_on_empty_dataset_(stop_on_empty_dataset) {
    selector_input_->Ref();

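    // The dataset's output shapes are computed component-wise as the most
    // specific shape compatible with the corresponding component of every
    // data input (see MostSpecificCompatibleShape below).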
    output_shapes_ = data_inputs_[0]->output_shapes();
    data_inputs_[0]->Ref();
    for (size_t i = 1; i < data_inputs_.size(); ++i) {
      const DatasetBase* data_input = data_inputs_[i];
      data_input->Ref();
      for (size_t j = 0; j < output_shapes_.size(); ++j) {
        output_shapes_[j] = MostSpecificCompatibleShape(
            output_shapes_[j], data_input->output_shapes()[j]);
      }
    }
  }

  ~Dataset() override {
    selector_input_->Unref();
    for (DatasetBase* data_input : data_inputs_) {
      data_input->Unref();
    }
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
                                split_providers) const override {
    TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this));
    return OkStatus();
  }

  const DataTypeVector& output_dtypes() const override {
    return data_inputs_[0]->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64_t CardinalityInternal() const override {
    // If any of the input datasets has infinite cardinality, then the output
    // cardinality is also infinite.
    for (const auto& input : data_inputs_) {
      int64_t n = input->Cardinality();
      if (n == kInfiniteCardinality) {
        return n;
      }
    }
    return kUnknownCardinality;
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(selector_input_);
    for (const auto& data_input : data_inputs_) {
      inputs->push_back(data_input);
    }
    return OkStatus();
  }

  Status CheckExternalState() const override {
    for (const auto& input : data_inputs_) {
      TF_RETURN_IF_ERROR(input->CheckExternalState());
    }
    return selector_input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* selector_input_node;
    TF_RETURN_IF_ERROR(
        b->AddInputDataset(ctx, selector_input_, &selector_input_node));
    std::vector<Node*> data_input_nodes(data_inputs_.size());
    for (size_t i = 0; i < data_inputs_.size(); ++i) {
      TF_RETURN_IF_ERROR(
          b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
    }

    // Attr: stop_on_empty_dataset
    AttrValue stop_on_empty_dataset_attr;
    b->BuildAttrValue(stop_on_empty_dataset_, &stop_on_empty_dataset_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        /*inputs=*/{{0, selector_input_node}},
        /*list_inputs=*/{{1, data_input_nodes}},
        /*attrs=*/
        {std::make_pair(kStopOnEmptyDataset, stop_on_empty_dataset_attr)},
        output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          num_active_inputs_(params.dataset->data_inputs_.size()) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      TF_ASSIGN_OR_RETURN(input_contexts_,
                          CreateInputIteratorContexts(ctx, dataset()));
      TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
          &input_contexts_[0], this, prefix(), &selector_input_impl_));
      data_input_impls_.resize(dataset()->data_inputs_.size());
      for (size_t i = 0; i < data_input_impls_.size(); ++i) {
        const DatasetBase* data_input = dataset()->data_inputs_[i];
        TF_RETURN_IF_ERROR(data_input->MakeIterator(
            &input_contexts_[i + 1], this,
            strings::StrCat(prefix(), "[", i, "]"), &data_input_impls_[i]));
      }
      return OkStatus();
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      if (!selector_input_impl_) {
        *end_of_sequence = true;
        return OkStatus();
      }

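      // Repeatedly draw an index from the selector and try to produce an
      // element from the corresponding data input. If the selected input is
      // exhausted, either end the sequence (when stop_on_empty_dataset is
      // set) or drop that input and keep drawing selector values.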
      while (true) {
        std::vector<Tensor> selector_result;
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
            &input_contexts_[0], &selector_result, end_of_sequence));
        if (*end_of_sequence) {
          ResetInputs();
          return OkStatus();
        }

        int64_t selected_input = selector_result[0].scalar<int64_t>()();
        if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
          return errors::InvalidArgument(
              "Selector index out of range: ", selected_input,
              " is not in [0, ", data_input_impls_.size(), ")");
        }

        if (data_input_impls_[selected_input]) {
          bool end_of_selected_input = false;
          TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
              &input_contexts_[selected_input + 1], out_tensors,
              &end_of_selected_input));

          if (!end_of_selected_input) {
            return OkStatus();
          }

          if (dataset()->stop_on_empty_dataset_) {
            *end_of_sequence = true;
            ResetInputs();
            return OkStatus();
          }

          data_input_impls_[selected_input].reset();
          --num_active_inputs_;

          if (num_active_inputs_ == 0) {
            selector_input_impl_.reset();
            *end_of_sequence = true;
            return OkStatus();
          }
        }

        VLOG(2) << "DirectedInterleave selected an exhausted input: "
                << selected_input;
      }
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeInterleaveManyNode(
          std::move(args),
          {model::MakeNonTunableParameter(kCycleLength, /*value=*/1)});
    }

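    // Checkpoints the selector iterator and each data input iterator. An
    // iterator that has already been exhausted (and reset) is recorded with an
    // explicit "*_empty" scalar so that RestoreInternal knows to leave it
    // unset instead of restoring it.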
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (selector_input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
      }
      for (size_t i = 0; i < data_input_impls_.size(); ++i) {
        const auto& data_input_impl = data_input_impls_[i];
        if (data_input_impl) {
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
        } else {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
              ""));
        }
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name("selector_input_impl_empty"))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
      } else {
        selector_input_impl_.reset();
      }
      for (size_t i = 0; i < data_input_impls_.size(); ++i) {
        if (!reader->Contains(
                full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
        } else {
          data_input_impls_[i].reset();
        }
      }
      return OkStatus();
    }

   private:
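    // Releases all input iterators and marks this iterator as exhausted.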
    void ResetInputs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      selector_input_impl_.reset();
      for (auto& data_input_impl : data_input_impls_) {
        data_input_impl.reset();
      }
      num_active_inputs_ = 0;
    }

    mutex mu_;
    // Iterator contexts for the input datasets. The first context is for the
    // selector input, and the remaining contexts are for the data inputs.
    std::vector<IteratorContext> input_contexts_;
    std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
    std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
        TF_GUARDED_BY(mu_);
    int64_t num_active_inputs_ TF_GUARDED_BY(mu_);
  };

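  // Returns the most specific shape that is compatible with both `ts1` and
  // `ts2`: if the ranks differ or either rank is unknown, the result has
  // unknown rank; otherwise each dimension is kept where the two shapes agree
  // and set to unknown (-1) where they differ.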
  static PartialTensorShape MostSpecificCompatibleShape(
      const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
    PartialTensorShape output_tensorshape;
    if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
      return output_tensorshape;
    auto dims1 = ts1.dim_sizes();
    auto dims2 = ts2.dim_sizes();
    for (int d = 0; d < ts1.dims(); ++d) {
      if (dims1[d] == dims2[d])
        output_tensorshape.Concatenate(dims1[d]);
      else
        output_tensorshape.Concatenate(-1);
    }
    return output_tensorshape;
  }

  const DatasetBase* const selector_input_;
  const std::vector<DatasetBase*> data_inputs_;
  std::vector<PartialTensorShape> output_shapes_;
  const bool stop_on_empty_dataset_;
};

DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {
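  // The `stop_on_empty_dataset` attr is not present in older graph defs, so
  // only read it when it exists; otherwise the member keeps its default value.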
  if (ctx->HasAttr(kStopOnEmptyDataset)) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr(kStopOnEmptyDataset, &stop_on_empty_dataset_));
  }
}

void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase** output) {
  DatasetBase* selector_input;
  OP_REQUIRES_OK(ctx,
                 GetDatasetFromVariantTensor(ctx->input(0), &selector_input));

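  // The selector must produce scalar int64 values that index into the data
  // inputs.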
  OP_REQUIRES(
      ctx,
      selector_input->output_dtypes().size() == 1 &&
          selector_input->output_dtypes()[0] == DT_INT64 &&
          selector_input->output_shapes().size() == 1 &&
          selector_input->output_shapes()[0].IsCompatibleWith(
              PartialTensorShape({})),
      errors::InvalidArgument(
          "The selector input must be a dataset of scalar int64 elements."));

  // The first input is the selector, followed by dataset inputs.
  std::vector<DatasetBase*> data_inputs;
  for (size_t i = 1; i < ctx->num_inputs(); ++i) {
    DatasetBase* input;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
    data_inputs.push_back(input);

    OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
                errors::InvalidArgument(
                    "All inputs must have the same output_dtypes. First input "
                    "has types ",
                    DataTypeVectorString(data_inputs[0]->output_dtypes()),
                    ", and input ", i - 1, " has types ",
                    DataTypeVectorString(input->output_dtypes())));
  }

  *output = new Dataset(ctx, selector_input, std::move(data_inputs),
                        stop_on_empty_dataset_);
}

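// The same kernel backs both the current op name and the legacy
// "Experimental" op name.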
namespace {
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
                        DirectedInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
    DirectedInterleaveDatasetOp);
}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow