1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/options_dataset_op.h"
16
17 #include "absl/memory/memory.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/dataset_options.pb.h"
21 #include "tensorflow/core/framework/partial_tensor_shape.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/profiler/lib/traceme.h"
24
25 namespace tensorflow {
26 namespace data {
27
28 /* static */ constexpr const char* const OptionsDatasetOp::kDatasetType;
29 /* static */ constexpr const char* const OptionsDatasetOp::kInputDataset;
30 /* static */ constexpr const char* const OptionsDatasetOp::kOutputTypes;
31 /* static */ constexpr const char* const OptionsDatasetOp::kOutputShapes;
32 /* static */ constexpr const char* const OptionsDatasetOp::kSerializedOptions;
33
34 class OptionsDatasetOp::Dataset : public DatasetBase {
35 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,const string & serialized_options)36 Dataset(OpKernelContext* ctx, const DatasetBase* input,
37 const string& serialized_options)
38 : DatasetBase(DatasetContext(ctx)),
39 input_(input),
40 serialized_options_(serialized_options) {
41 input_->Ref();
42 Options options;
43 OP_REQUIRES(ctx, options.ParseFromString(serialized_options),
44 errors::InvalidArgument(absl::StrCat(
45 "Could not parse ", OptionsDatasetOp::kSerializedOptions,
46 " as valid Options.")));
47 set_options(options);
48 }
49
~Dataset()50 ~Dataset() override { input_->Unref(); }
51
MakeIteratorInternal(const string & prefix) const52 std::unique_ptr<IteratorBase> MakeIteratorInternal(
53 const string& prefix) const override {
54 DCHECK(false) << "OptionsDatasetOp::Dataset::MakeIteratorInternal is not "
55 "expected to be called because it is supposed to forward "
56 "the iterator to its input dataset(s).";
57 LOG(ERROR) << "Datasets of type " << type_string()
58 << " forwards its iterator to its input dataset. "
59 "`MakeIteratorInternal` is not implemented.";
60 return nullptr;
61 }
62
output_dtypes() const63 const DataTypeVector& output_dtypes() const override {
64 return input_->output_dtypes();
65 }
output_shapes() const66 const std::vector<PartialTensorShape>& output_shapes() const override {
67 return input_->output_shapes();
68 }
69
CardinalityInternal() const70 int64_t CardinalityInternal() const override { return input_->Cardinality(); }
71
CardinalityInternal(CardinalityOptions options) const72 int64_t CardinalityInternal(CardinalityOptions options) const override {
73 return input_->Cardinality(options);
74 }
75
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const76 Status Get(OpKernelContext* ctx, int64 index,
77 std::vector<Tensor>* out_tensors) const override {
78 return input_->Get(ctx, index, out_tensors);
79 }
80
DebugString() const81 string DebugString() const override {
82 return name_utils::DatasetDebugString(kDatasetType);
83 }
84
InputDatasets(std::vector<const DatasetBase * > * inputs) const85 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
86 inputs->push_back(input_);
87 return OkStatus();
88 }
89
CheckExternalState() const90 Status CheckExternalState() const override {
91 return input_->CheckExternalState();
92 }
93
94 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const95 Status AsGraphDefInternal(SerializationContext* ctx,
96 DatasetGraphDefBuilder* b,
97 Node** output) const override {
98 Node* input_graph_node = nullptr;
99 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
100 AttrValue serialized_options_attr;
101 b->BuildAttrValue(serialized_options_, &serialized_options_attr);
102 TF_RETURN_IF_ERROR(b->AddDataset(
103 this, {input_graph_node},
104 {std::make_pair(kSerializedOptions, serialized_options_attr)}, output));
105 return OkStatus();
106 }
107
108 private:
109 const DatasetBase* input_;
110 const tstring serialized_options_;
111 };
112
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)113 void OptionsDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
114 DatasetBase* input;
115 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
116 *output = new Dataset(ctx, input, serialized_options_);
117 }
118
OptionsDatasetOp(OpKernelConstruction * ctx)119 OptionsDatasetOp::OptionsDatasetOp(OpKernelConstruction* ctx)
120 : DatasetOpKernel(ctx) {
121 OP_REQUIRES_OK(ctx, ctx->GetAttr(kSerializedOptions, &serialized_options_));
122 }
123
124 namespace {
125 REGISTER_KERNEL_BUILDER(Name("OptionsDataset").Device(DEVICE_CPU).Priority(2),
126 OptionsDatasetOp);
127 REGISTER_KERNEL_BUILDER(Name("OptionsDataset")
128 .Device(DEVICE_GPU)
129 .HostMemory("input_dataset")
130 .HostMemory("handle")
131 .Priority(1),
132 OptionsDatasetOp);
133 } // namespace
134 } // namespace data
135 } // namespace tensorflow
136