1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/map_dataset_op.h"
16
17 #include "tensorflow/core/common_runtime/function.h"
18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
19 #include "tensorflow/core/data/captured_function.h"
20 #include "tensorflow/core/data/dataset_utils.h"
21 #include "tensorflow/core/data/name_utils.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/lib/random/random.h"
25
26 namespace tensorflow {
27 namespace data {
28
29 // See documentation in ../../ops/dataset_ops.cc for a high-level
30 // description of the following op.
31
// Namespace-scope definitions for the static constexpr string members of
// MapDatasetOp. No initializers appear here, so the values themselves come
// from the in-class declarations (presumably in map_dataset_op.h -- confirm).
// NOTE(review): these definitions are only required for ODR-use before C++17;
// from C++17 on, static constexpr members are implicitly inline.
/* static */ constexpr const char* const MapDatasetOp::kDatasetType;
/* static */ constexpr const char* const MapDatasetOp::kInputDataset;
/* static */ constexpr const char* const MapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const MapDatasetOp::kFunc;
/* static */ constexpr const char* const MapDatasetOp::kTarguments;
/* static */ constexpr const char* const MapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const MapDatasetOp::kOutputShapes;
/* static */ constexpr const char* const MapDatasetOp::kUseInterOpParallelism;
/* static */ constexpr const char* const MapDatasetOp::kPreserveCardinality;
41
42 class MapDatasetOp::Dataset : public DatasetBase {
43 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes,bool preserve_cardinality)44 Dataset(OpKernelContext* ctx, const DatasetBase* input,
45 std::unique_ptr<CapturedFunction> captured_func,
46 const DataTypeVector& output_types,
47 const std::vector<PartialTensorShape>& output_shapes,
48 bool preserve_cardinality)
49 : DatasetBase(DatasetContext(ctx)),
50 input_(input),
51 preserve_cardinality_(preserve_cardinality),
52 captured_func_(std::move(captured_func)),
53 output_types_(output_types),
54 output_shapes_(output_shapes) {
55 input_->Ref();
56 }
57
~Dataset()58 ~Dataset() override { input_->Unref(); }
59
MakeIteratorInternal(const string & prefix) const60 std::unique_ptr<IteratorBase> MakeIteratorInternal(
61 const string& prefix) const override {
62 return std::make_unique<Iterator>(Iterator::Params{
63 this, name_utils::IteratorPrefix(kDatasetType, prefix)});
64 }
65
output_dtypes() const66 const DataTypeVector& output_dtypes() const override { return output_types_; }
67
output_shapes() const68 const std::vector<PartialTensorShape>& output_shapes() const override {
69 return output_shapes_;
70 }
71
DebugString() const72 string DebugString() const override {
73 return name_utils::DatasetDebugString(kDatasetType);
74 }
75
CardinalityInternal() const76 int64_t CardinalityInternal() const override {
77 if (preserve_cardinality_) {
78 return input_->Cardinality();
79 } else {
80 return kUnknownCardinality;
81 }
82 }
83
CardinalityInternal(CardinalityOptions options) const84 int64_t CardinalityInternal(CardinalityOptions options) const override {
85 if (preserve_cardinality_) {
86 return input_->Cardinality(options);
87 } else {
88 return kUnknownCardinality;
89 }
90 }
91
InputDatasets(std::vector<const DatasetBase * > * inputs) const92 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
93 inputs->push_back(input_);
94 return OkStatus();
95 }
96
CheckExternalState() const97 Status CheckExternalState() const override {
98 TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
99 return input_->CheckExternalState();
100 }
101
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const102 Status Get(OpKernelContext* ctx, int64 index,
103 std::vector<Tensor>* out_tensors) const override {
104 TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
105 std::vector<Tensor> args;
106 TF_RETURN_IF_ERROR(input_->Get(ctx, index, &args));
107 if (!instantiated_captured_func_) {
108 TF_RETURN_IF_ERROR(
109 captured_func_->Instantiate(InstantiateCapturedFunctionParams(ctx),
110 &instantiated_captured_func_));
111 }
112 return instantiated_captured_func_->RunInstantiated(args, out_tensors);
113 }
114
115 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const116 Status AsGraphDefInternal(SerializationContext* ctx,
117 DatasetGraphDefBuilder* b,
118 Node** output) const override {
119 Node* input_graph_node = nullptr;
120 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
121
122 std::vector<Node*> other_arguments;
123 DataTypeVector other_arguments_types;
124 TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
125 &other_arguments_types));
126
127 // Attr: f
128 AttrValue f_attr;
129 b->BuildAttrValue(captured_func_->func(), &f_attr);
130
131 // Attr: Targuments
132 AttrValue other_arguments_types_attr;
133 b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
134
135 // Attr: use_inter_op_parallelism
136 AttrValue use_inter_op_parallelism_attr;
137 b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
138 &use_inter_op_parallelism_attr);
139
140 // Attr: preserve_cardinality
141 AttrValue preserve_cardinality_attr;
142 b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
143
144 TF_RETURN_IF_ERROR(b->AddDataset(
145 this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs.
146 {std::make_pair(1, other_arguments)}, // Tensor list inputs.
147 {std::make_pair(kFunc, f_attr),
148 std::make_pair(kTarguments, other_arguments_types_attr),
149 std::make_pair(kUseInterOpParallelism, use_inter_op_parallelism_attr),
150 std::make_pair(kPreserveCardinality,
151 preserve_cardinality_attr)}, // Attrs
152 output));
153 return OkStatus();
154 }
155
156 private:
157 class Iterator : public DatasetIterator<Dataset> {
158 public:
Iterator(const Params & params)159 explicit Iterator(const Params& params)
160 : DatasetIterator<Dataset>(params) {}
161
Initialize(IteratorContext * ctx)162 Status Initialize(IteratorContext* ctx) override {
163 TF_RETURN_IF_ERROR(
164 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
165 return dataset()->captured_func_->Instantiate(
166 ctx, &instantiated_captured_func_);
167 }
168
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)169 Status GetNextInternal(IteratorContext* ctx,
170 std::vector<Tensor>* out_tensors,
171 bool* end_of_sequence) override {
172 // NOTE(mrry): This method is thread-safe as long as
173 // `input_impl_` and `f` are thread-safe. However, if multiple
174 // threads enter this method, outputs may be observed in a
175 // non-deterministic order.
176
177 std::vector<Tensor> args;
178 TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
179 if (*end_of_sequence) {
180 return OkStatus();
181 }
182
183 Status s = instantiated_captured_func_->Run(ctx, std::move(args),
184 out_tensors, model_node());
185 if (errors::IsOutOfRange(s)) {
186 if (dataset()->preserve_cardinality_) {
187 // To guarantee that the transformation preserves the cardinality of
188 // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
189 // former may be interpreted by a caller as the end of sequence.
190 return errors::InvalidArgument(
191 "Function invocation produced OutOfRangeError: ",
192 s.error_message());
193 } else {
194 // `f` may deliberately raise `errors::OutOfRange` to indicate
195 // that we should terminate the iteration early.
196 *end_of_sequence = true;
197 return OkStatus();
198 }
199 } else {
200 return s;
201 }
202 }
203
204 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const205 std::shared_ptr<model::Node> CreateNode(
206 IteratorContext* ctx, model::Node::Args args) const override {
207 return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
208 }
209
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)210 Status SaveInternal(SerializationContext* ctx,
211 IteratorStateWriter* writer) override {
212 TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
213 dataset()->captured_func_->CheckExternalState()));
214 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
215 return OkStatus();
216 }
217
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)218 Status RestoreInternal(IteratorContext* ctx,
219 IteratorStateReader* reader) override {
220 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
221 return OkStatus();
222 }
223
224 private:
225 std::unique_ptr<IteratorBase> input_impl_;
226 std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
227 };
228
229 const DatasetBase* const input_;
230 const bool preserve_cardinality_;
231 const std::unique_ptr<CapturedFunction> captured_func_;
232 const DataTypeVector output_types_;
233 const std::vector<PartialTensorShape> output_shapes_;
234 // This is used for random access provided by Get().
235 mutable std::unique_ptr<InstantiatedCapturedFunction>
236 instantiated_captured_func_;
237 };
238
MapDatasetOp(OpKernelConstruction * ctx)239 MapDatasetOp::MapDatasetOp(OpKernelConstruction* ctx)
240 : UnaryDatasetOpKernel(ctx) {
241 FunctionMetadata::Params params;
242 OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
243 ¶ms.use_inter_op_parallelism));
244 OP_REQUIRES_OK(ctx,
245 FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
246 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
247 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
248 OP_REQUIRES_OK(ctx,
249 ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
250 }
251
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)252 void MapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
253 DatasetBase** output) {
254 std::unique_ptr<CapturedFunction> captured_func;
255 OP_REQUIRES_OK(ctx,
256 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
257 &captured_func));
258
259 *output = new Dataset(ctx, input, std::move(captured_func), output_types_,
260 output_shapes_, preserve_cardinality_);
261 }
262
namespace {

// Kernel registrations for MapDatasetOp.
// "MapDataset" is registered for the CPU device.
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
// "ExperimentalMapDataset" is registered for the GPU device, with its
// `input_dataset` and `handle` arguments pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("input_dataset")
                            .HostMemory("handle"),
                        MapDatasetOp);
// Adds "MapDataset" to the input colocation exemption registry (see
// input_colocation_exemption_registry.h, included above).
REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");

}  // namespace
274 } // namespace data
275 } // namespace tensorflow
276