xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/options_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/options_dataset_op.h"
16 
17 #include "absl/memory/memory.h"
18 #include "tensorflow/core/data/name_utils.h"
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/dataset_options.pb.h"
21 #include "tensorflow/core/framework/partial_tensor_shape.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/profiler/lib/traceme.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 /* static */ constexpr const char* const OptionsDatasetOp::kDatasetType;
29 /* static */ constexpr const char* const OptionsDatasetOp::kInputDataset;
30 /* static */ constexpr const char* const OptionsDatasetOp::kOutputTypes;
31 /* static */ constexpr const char* const OptionsDatasetOp::kOutputShapes;
32 /* static */ constexpr const char* const OptionsDatasetOp::kSerializedOptions;
33 
34 class OptionsDatasetOp::Dataset : public DatasetBase {
35  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,const string & serialized_options)36   Dataset(OpKernelContext* ctx, const DatasetBase* input,
37           const string& serialized_options)
38       : DatasetBase(DatasetContext(ctx)),
39         input_(input),
40         serialized_options_(serialized_options) {
41     input_->Ref();
42     Options options;
43     OP_REQUIRES(ctx, options.ParseFromString(serialized_options),
44                 errors::InvalidArgument(absl::StrCat(
45                     "Could not parse ", OptionsDatasetOp::kSerializedOptions,
46                     " as valid Options.")));
47     set_options(options);
48   }
49 
~Dataset()50   ~Dataset() override { input_->Unref(); }
51 
MakeIteratorInternal(const string & prefix) const52   std::unique_ptr<IteratorBase> MakeIteratorInternal(
53       const string& prefix) const override {
54     DCHECK(false) << "OptionsDatasetOp::Dataset::MakeIteratorInternal is not "
55                      "expected to be called because it is supposed to forward "
56                      "the iterator to its input dataset(s).";
57     LOG(ERROR) << "Datasets of type " << type_string()
58                << " forwards its iterator to its input dataset. "
59                   "`MakeIteratorInternal` is not implemented.";
60     return nullptr;
61   }
62 
output_dtypes() const63   const DataTypeVector& output_dtypes() const override {
64     return input_->output_dtypes();
65   }
output_shapes() const66   const std::vector<PartialTensorShape>& output_shapes() const override {
67     return input_->output_shapes();
68   }
69 
CardinalityInternal() const70   int64_t CardinalityInternal() const override { return input_->Cardinality(); }
71 
CardinalityInternal(CardinalityOptions options) const72   int64_t CardinalityInternal(CardinalityOptions options) const override {
73     return input_->Cardinality(options);
74   }
75 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const76   Status Get(OpKernelContext* ctx, int64 index,
77              std::vector<Tensor>* out_tensors) const override {
78     return input_->Get(ctx, index, out_tensors);
79   }
80 
DebugString() const81   string DebugString() const override {
82     return name_utils::DatasetDebugString(kDatasetType);
83   }
84 
InputDatasets(std::vector<const DatasetBase * > * inputs) const85   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
86     inputs->push_back(input_);
87     return OkStatus();
88   }
89 
CheckExternalState() const90   Status CheckExternalState() const override {
91     return input_->CheckExternalState();
92   }
93 
94  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const95   Status AsGraphDefInternal(SerializationContext* ctx,
96                             DatasetGraphDefBuilder* b,
97                             Node** output) const override {
98     Node* input_graph_node = nullptr;
99     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
100     AttrValue serialized_options_attr;
101     b->BuildAttrValue(serialized_options_, &serialized_options_attr);
102     TF_RETURN_IF_ERROR(b->AddDataset(
103         this, {input_graph_node},
104         {std::make_pair(kSerializedOptions, serialized_options_attr)}, output));
105     return OkStatus();
106   }
107 
108  private:
109   const DatasetBase* input_;
110   const tstring serialized_options_;
111 };
112 
MakeDataset(OpKernelContext * ctx,DatasetBase ** output)113 void OptionsDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
114   DatasetBase* input;
115   OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
116   *output = new Dataset(ctx, input, serialized_options_);
117 }
118 
OptionsDatasetOp(OpKernelConstruction * ctx)119 OptionsDatasetOp::OptionsDatasetOp(OpKernelConstruction* ctx)
120     : DatasetOpKernel(ctx) {
121   OP_REQUIRES_OK(ctx, ctx->GetAttr(kSerializedOptions, &serialized_options_));
122 }
123 
124 namespace {
125 REGISTER_KERNEL_BUILDER(Name("OptionsDataset").Device(DEVICE_CPU).Priority(2),
126                         OptionsDatasetOp);
127 REGISTER_KERNEL_BUILDER(Name("OptionsDataset")
128                             .Device(DEVICE_GPU)
129                             .HostMemory("input_dataset")
130                             .HostMemory("handle")
131                             .Priority(1),
132                         OptionsDatasetOp);
133 }  // namespace
134 }  // namespace data
135 }  // namespace tensorflow
136