/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/hash/hash.h"
25 
26 namespace tensorflow {
27 namespace data {
28 namespace experimental {
29 
30 /* static */ constexpr const char* const
31     DirectedInterleaveDatasetOp::kDatasetType;
32 /* static */ constexpr const char* const
33     DirectedInterleaveDatasetOp::kSelectorInputDataset;
34 /* static */ constexpr const char* const
35     DirectedInterleaveDatasetOp::kDataInputDatasets;
36 /* static */ constexpr const char* const
37     DirectedInterleaveDatasetOp::kStopOnEmptyDataset;
38 /* static */ constexpr const char* const
39     DirectedInterleaveDatasetOp::kOutputTypes;
40 /* static */ constexpr const char* const
41     DirectedInterleaveDatasetOp::kOutputShapes;
42 /* static */ constexpr const char* const
43     DirectedInterleaveDatasetOp::kNumInputDatasets;
44 
45 constexpr char kCycleLength[] = "cycle_length";
46 
47 class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
48  public:
Dataset(OpKernelContext * ctx,const DatasetBase * selector_input,std::vector<DatasetBase * > data_inputs,bool stop_on_empty_dataset)49   Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
50           std::vector<DatasetBase*> data_inputs, bool stop_on_empty_dataset)
51       : DatasetBase(DatasetContext(ctx)),
52         selector_input_(selector_input),
53         data_inputs_(std::move(data_inputs)),
54         stop_on_empty_dataset_(stop_on_empty_dataset) {
55     selector_input_->Ref();
56 
57     output_shapes_ = data_inputs_[0]->output_shapes();
58     data_inputs_[0]->Ref();
59     for (size_t i = 1; i < data_inputs_.size(); ++i) {
60       const DatasetBase* data_input = data_inputs_[i];
61       data_input->Ref();
62       for (size_t j = 0; j < output_shapes_.size(); ++j) {
63         output_shapes_[j] = MostSpecificCompatibleShape(
64             output_shapes_[j], data_input->output_shapes()[j]);
65       }
66     }
67   }
68 
~Dataset()69   ~Dataset() override {
70     selector_input_->Unref();
71     for (DatasetBase* data_input : data_inputs_) {
72       data_input->Unref();
73     }
74   }
75 
MakeIteratorInternal(const string & prefix) const76   std::unique_ptr<IteratorBase> MakeIteratorInternal(
77       const string& prefix) const override {
78     return std::make_unique<Iterator>(Iterator::Params{
79         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
80   }
81 
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * split_providers) const82   Status MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>>*
83                                 split_providers) const override {
84     TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this));
85     return OkStatus();
86   }
87 
output_dtypes() const88   const DataTypeVector& output_dtypes() const override {
89     return data_inputs_[0]->output_dtypes();
90   }
91 
output_shapes() const92   const std::vector<PartialTensorShape>& output_shapes() const override {
93     return output_shapes_;
94   }
95 
DebugString() const96   string DebugString() const override {
97     return name_utils::DatasetDebugString(kDatasetType);
98   }
99 
CardinalityInternal() const100   int64_t CardinalityInternal() const override {
101     // As long as one of input dataset has infinite cardinality, the output
102     // cardinality is infinite.
103     for (const auto& input : data_inputs_) {
104       int64_t n = input->Cardinality();
105       if (n == kInfiniteCardinality) {
106         return n;
107       }
108     }
109     return kUnknownCardinality;
110   }
111 
InputDatasets(std::vector<const DatasetBase * > * inputs) const112   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
113     inputs->push_back(selector_input_);
114     for (const auto& data_input : data_inputs_) {
115       inputs->push_back(data_input);
116     }
117     return OkStatus();
118   }
119 
CheckExternalState() const120   Status CheckExternalState() const override {
121     for (const auto& input : data_inputs_) {
122       TF_RETURN_IF_ERROR(input->CheckExternalState());
123     }
124     return selector_input_->CheckExternalState();
125   }
126 
127  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const128   Status AsGraphDefInternal(SerializationContext* ctx,
129                             DatasetGraphDefBuilder* b,
130                             Node** output) const override {
131     Node* selector_input_node;
132     TF_RETURN_IF_ERROR(
133         b->AddInputDataset(ctx, selector_input_, &selector_input_node));
134     std::vector<Node*> data_input_nodes(data_inputs_.size());
135     for (size_t i = 0; i < data_inputs_.size(); ++i) {
136       TF_RETURN_IF_ERROR(
137           b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
138     }
139 
140     // Attr: stop_on_empty_dataset
141     AttrValue stop_on_empty_dataset_attr;
142     b->BuildAttrValue(stop_on_empty_dataset_, &stop_on_empty_dataset_attr);
143 
144     TF_RETURN_IF_ERROR(b->AddDataset(
145         this,
146         /*inputs=*/{{0, selector_input_node}},
147         /*list_inputs=*/{{1, data_input_nodes}},
148         /*attrs=*/
149         {std::make_pair(kStopOnEmptyDataset, stop_on_empty_dataset_attr)},
150         output));
151     return OkStatus();
152   }
153 
154  private:
155   class Iterator : public DatasetIterator<Dataset> {
156    public:
Iterator(const Params & params)157     explicit Iterator(const Params& params)
158         : DatasetIterator<Dataset>(params),
159           num_active_inputs_(params.dataset->data_inputs_.size()) {}
160 
Initialize(IteratorContext * ctx)161     Status Initialize(IteratorContext* ctx) override {
162       mutex_lock l(mu_);
163       TF_ASSIGN_OR_RETURN(input_contexts_,
164                           CreateInputIteratorContexts(ctx, dataset()));
165       TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
166           &input_contexts_[0], this, prefix(), &selector_input_impl_));
167       data_input_impls_.resize(dataset()->data_inputs_.size());
168       for (size_t i = 0; i < data_input_impls_.size(); ++i) {
169         const DatasetBase* data_input = dataset()->data_inputs_[i];
170         TF_RETURN_IF_ERROR(data_input->MakeIterator(
171             &input_contexts_[i + 1], this,
172             strings::StrCat(prefix(), "[", i, "]"), &data_input_impls_[i]));
173       }
174       return OkStatus();
175     }
176 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)177     Status GetNextInternal(IteratorContext* ctx,
178                            std::vector<Tensor>* out_tensors,
179                            bool* end_of_sequence) override {
180       mutex_lock l(mu_);
181       if (!selector_input_impl_) {
182         *end_of_sequence = true;
183         return OkStatus();
184       }
185 
186       while (true) {
187         std::vector<Tensor> selector_result;
188         *end_of_sequence = false;
189         TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
190             &input_contexts_[0], &selector_result, end_of_sequence));
191         if (*end_of_sequence) {
192           ResetInputs();
193           return OkStatus();
194         }
195 
196         int64_t selected_input = selector_result[0].scalar<int64_t>()();
197         if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
198           return errors::InvalidArgument(
199               "Selector index out of range: ", selected_input,
200               " >= ", data_input_impls_.size());
201         }
202 
203         if (data_input_impls_[selected_input]) {
204           bool end_of_selected_input = false;
205           TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
206               &input_contexts_[selected_input + 1], out_tensors,
207               &end_of_selected_input));
208 
209           if (!end_of_selected_input) {
210             return OkStatus();
211           }
212 
213           if (dataset()->stop_on_empty_dataset_) {
214             *end_of_sequence = true;
215             ResetInputs();
216             return OkStatus();
217           }
218 
219           data_input_impls_[selected_input].reset();
220           --num_active_inputs_;
221 
222           if (num_active_inputs_ == 0) {
223             selector_input_impl_.reset();
224             *end_of_sequence = true;
225             return OkStatus();
226           }
227         }
228 
229         VLOG(2) << "DirectedInterleave selected an exhausted input: "
230                 << selected_input;
231       }
232     }
233 
234    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const235     std::shared_ptr<model::Node> CreateNode(
236         IteratorContext* ctx, model::Node::Args args) const override {
237       return model::MakeInterleaveManyNode(
238           std::move(args),
239           {model::MakeNonTunableParameter(kCycleLength, /*value=*/1)});
240     }
241 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)242     Status SaveInternal(SerializationContext* ctx,
243                         IteratorStateWriter* writer) override {
244       mutex_lock l(mu_);
245       if (selector_input_impl_) {
246         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
247       } else {
248         TF_RETURN_IF_ERROR(
249             writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
250       }
251       for (size_t i = 0; i < data_input_impls_.size(); ++i) {
252         const auto& data_input_impl = data_input_impls_[i];
253         if (data_input_impl) {
254           TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
255         } else {
256           TF_RETURN_IF_ERROR(writer->WriteScalar(
257               full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
258               ""));
259         }
260       }
261       return OkStatus();
262     }
263 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)264     Status RestoreInternal(IteratorContext* ctx,
265                            IteratorStateReader* reader) override {
266       mutex_lock l(mu_);
267       if (!reader->Contains(full_name("selector_input_impl_empty"))) {
268         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
269       } else {
270         selector_input_impl_.reset();
271       }
272       for (size_t i = 0; i < data_input_impls_.size(); ++i) {
273         if (!reader->Contains(
274                 full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
275           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
276         } else {
277           data_input_impls_[i].reset();
278         }
279       }
280       return OkStatus();
281     }
282 
283    private:
ResetInputs()284     void ResetInputs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
285       selector_input_impl_.reset();
286       for (auto& data_input_impl : data_input_impls_) {
287         data_input_impl.reset();
288       }
289       num_active_inputs_ = 0;
290     }
291 
292     mutex mu_;
293     // Iterator contexts for inputs datasets. The first context is for the
294     // selector input, and the remaning contexts are for the data inputs.
295     std::vector<IteratorContext> input_contexts_;
296     std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
297     std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
298         TF_GUARDED_BY(mu_);
299     int64_t num_active_inputs_ TF_GUARDED_BY(mu_);
300   };
301 
MostSpecificCompatibleShape(const PartialTensorShape & ts1,const PartialTensorShape & ts2)302   static PartialTensorShape MostSpecificCompatibleShape(
303       const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
304     PartialTensorShape output_tensorshape;
305     if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
306       return output_tensorshape;
307     auto dims1 = ts1.dim_sizes();
308     auto dims2 = ts2.dim_sizes();
309     for (int d = 0; d < ts1.dims(); ++d) {
310       if (dims1[d] == dims2[d])
311         output_tensorshape.Concatenate(dims1[d]);
312       else
313         output_tensorshape.Concatenate(-1);
314     }
315     return output_tensorshape;
316   }
317 
318   const DatasetBase* const selector_input_;
319   const std::vector<DatasetBase*> data_inputs_;
320   std::vector<PartialTensorShape> output_shapes_;
321   const bool stop_on_empty_dataset_;
322 };
323 
DirectedInterleaveDatasetOp(OpKernelConstruction * ctx)324 DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
325     OpKernelConstruction* ctx)
326     : DatasetOpKernel(ctx) {
327   if (ctx->HasAttr(kStopOnEmptyDataset)) {
328     OP_REQUIRES_OK(ctx,
329                    ctx->GetAttr(kStopOnEmptyDataset, &stop_on_empty_dataset_));
330   }
331 }
332 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)333 void DirectedInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
334                                               DatasetBase** output) {
335   DatasetBase* selector_input;
336   OP_REQUIRES_OK(ctx,
337                  GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
338 
339   OP_REQUIRES(
340       ctx,
341       selector_input->output_dtypes().size() == 1 &&
342           selector_input->output_dtypes()[0] == DT_INT64 &&
343           selector_input->output_shapes().size() == 1 &&
344           selector_input->output_shapes()[0].IsCompatibleWith(
345               PartialTensorShape({})),
346       errors::InvalidArgument(
347           "The selector input must be a dataset of scalar int64 elements."));
348 
349   // The first input is the selector, followed by dataset inputs.
350   std::vector<DatasetBase*> data_inputs;
351   for (size_t i = 1; i < ctx->num_inputs(); ++i) {
352     DatasetBase* input;
353     OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
354     data_inputs.push_back(input);
355 
356     OP_REQUIRES(ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
357                 errors::InvalidArgument(
358                     "All inputs must have the same output_dtypes. First input "
359                     "has types ",
360                     DataTypeVectorString(data_inputs[0]->output_dtypes()),
361                     ", and input ", i - 1, " has types ",
362                     DataTypeVectorString(input->output_dtypes())));
363   }
364 
365   *output = new Dataset(ctx, selector_input, std::move(data_inputs),
366                         stop_on_empty_dataset_);
367 }
368 
369 namespace {
370 REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
371                         DirectedInterleaveDatasetOp);
372 REGISTER_KERNEL_BUILDER(
373     Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
374     DirectedInterleaveDatasetOp);
375 }  // namespace
376 }  // namespace experimental
377 }  // namespace data
378 }  // namespace tensorflow
379