xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/map_dataset_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/kernels/data/map_dataset_op.h"

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/captured_function.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/mutex.h"

25 
26 namespace tensorflow {
27 namespace data {
28 
29 // See documentation in ../../ops/dataset_ops.cc for a high-level
30 // description of the following op.
31 
// Out-of-class definitions of the static constexpr string constants declared
// in map_dataset_op.h. Before C++17 (where static constexpr data members
// became implicitly inline) these definitions are required whenever a
// constant is odr-used, e.g. bound to a reference or its address is taken.
/* static */ constexpr const char* const MapDatasetOp::kDatasetType;
/* static */ constexpr const char* const MapDatasetOp::kInputDataset;
/* static */ constexpr const char* const MapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const MapDatasetOp::kFunc;
/* static */ constexpr const char* const MapDatasetOp::kTarguments;
/* static */ constexpr const char* const MapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const MapDatasetOp::kOutputShapes;
/* static */ constexpr const char* const MapDatasetOp::kUseInterOpParallelism;
/* static */ constexpr const char* const MapDatasetOp::kPreserveCardinality;
41 
42 class MapDatasetOp::Dataset : public DatasetBase {
43  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes,bool preserve_cardinality)44   Dataset(OpKernelContext* ctx, const DatasetBase* input,
45           std::unique_ptr<CapturedFunction> captured_func,
46           const DataTypeVector& output_types,
47           const std::vector<PartialTensorShape>& output_shapes,
48           bool preserve_cardinality)
49       : DatasetBase(DatasetContext(ctx)),
50         input_(input),
51         preserve_cardinality_(preserve_cardinality),
52         captured_func_(std::move(captured_func)),
53         output_types_(output_types),
54         output_shapes_(output_shapes) {
55     input_->Ref();
56   }
57 
~Dataset()58   ~Dataset() override { input_->Unref(); }
59 
MakeIteratorInternal(const string & prefix) const60   std::unique_ptr<IteratorBase> MakeIteratorInternal(
61       const string& prefix) const override {
62     return std::make_unique<Iterator>(Iterator::Params{
63         this, name_utils::IteratorPrefix(kDatasetType, prefix)});
64   }
65 
output_dtypes() const66   const DataTypeVector& output_dtypes() const override { return output_types_; }
67 
output_shapes() const68   const std::vector<PartialTensorShape>& output_shapes() const override {
69     return output_shapes_;
70   }
71 
DebugString() const72   string DebugString() const override {
73     return name_utils::DatasetDebugString(kDatasetType);
74   }
75 
CardinalityInternal() const76   int64_t CardinalityInternal() const override {
77     if (preserve_cardinality_) {
78       return input_->Cardinality();
79     } else {
80       return kUnknownCardinality;
81     }
82   }
83 
CardinalityInternal(CardinalityOptions options) const84   int64_t CardinalityInternal(CardinalityOptions options) const override {
85     if (preserve_cardinality_) {
86       return input_->Cardinality(options);
87     } else {
88       return kUnknownCardinality;
89     }
90   }
91 
InputDatasets(std::vector<const DatasetBase * > * inputs) const92   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
93     inputs->push_back(input_);
94     return OkStatus();
95   }
96 
CheckExternalState() const97   Status CheckExternalState() const override {
98     TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
99     return input_->CheckExternalState();
100   }
101 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const102   Status Get(OpKernelContext* ctx, int64 index,
103              std::vector<Tensor>* out_tensors) const override {
104     TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
105     std::vector<Tensor> args;
106     TF_RETURN_IF_ERROR(input_->Get(ctx, index, &args));
107     if (!instantiated_captured_func_) {
108       TF_RETURN_IF_ERROR(
109           captured_func_->Instantiate(InstantiateCapturedFunctionParams(ctx),
110                                       &instantiated_captured_func_));
111     }
112     return instantiated_captured_func_->RunInstantiated(args, out_tensors);
113   }
114 
115  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const116   Status AsGraphDefInternal(SerializationContext* ctx,
117                             DatasetGraphDefBuilder* b,
118                             Node** output) const override {
119     Node* input_graph_node = nullptr;
120     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
121 
122     std::vector<Node*> other_arguments;
123     DataTypeVector other_arguments_types;
124     TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
125                                                   &other_arguments_types));
126 
127     // Attr: f
128     AttrValue f_attr;
129     b->BuildAttrValue(captured_func_->func(), &f_attr);
130 
131     // Attr: Targuments
132     AttrValue other_arguments_types_attr;
133     b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
134 
135     // Attr: use_inter_op_parallelism
136     AttrValue use_inter_op_parallelism_attr;
137     b->BuildAttrValue(captured_func_->use_inter_op_parallelism(),
138                       &use_inter_op_parallelism_attr);
139 
140     // Attr: preserve_cardinality
141     AttrValue preserve_cardinality_attr;
142     b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);
143 
144     TF_RETURN_IF_ERROR(b->AddDataset(
145         this, {std::make_pair(0, input_graph_node)},  // Single tensor inputs.
146         {std::make_pair(1, other_arguments)},         // Tensor list inputs.
147         {std::make_pair(kFunc, f_attr),
148          std::make_pair(kTarguments, other_arguments_types_attr),
149          std::make_pair(kUseInterOpParallelism, use_inter_op_parallelism_attr),
150          std::make_pair(kPreserveCardinality,
151                         preserve_cardinality_attr)},  // Attrs
152         output));
153     return OkStatus();
154   }
155 
156  private:
157   class Iterator : public DatasetIterator<Dataset> {
158    public:
Iterator(const Params & params)159     explicit Iterator(const Params& params)
160         : DatasetIterator<Dataset>(params) {}
161 
Initialize(IteratorContext * ctx)162     Status Initialize(IteratorContext* ctx) override {
163       TF_RETURN_IF_ERROR(
164           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
165       return dataset()->captured_func_->Instantiate(
166           ctx, &instantiated_captured_func_);
167     }
168 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)169     Status GetNextInternal(IteratorContext* ctx,
170                            std::vector<Tensor>* out_tensors,
171                            bool* end_of_sequence) override {
172       // NOTE(mrry): This method is thread-safe as long as
173       // `input_impl_` and `f` are thread-safe. However, if multiple
174       // threads enter this method, outputs may be observed in a
175       // non-deterministic order.
176 
177       std::vector<Tensor> args;
178       TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
179       if (*end_of_sequence) {
180         return OkStatus();
181       }
182 
183       Status s = instantiated_captured_func_->Run(ctx, std::move(args),
184                                                   out_tensors, model_node());
185       if (errors::IsOutOfRange(s)) {
186         if (dataset()->preserve_cardinality_) {
187           // To guarantee that the transformation preserves the cardinality of
188           // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
189           // former may be interpreted by a caller as the end of sequence.
190           return errors::InvalidArgument(
191               "Function invocation produced OutOfRangeError: ",
192               s.error_message());
193         } else {
194           // `f` may deliberately raise `errors::OutOfRange` to indicate
195           // that we should terminate the iteration early.
196           *end_of_sequence = true;
197           return OkStatus();
198         }
199       } else {
200         return s;
201       }
202     }
203 
204    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const205     std::shared_ptr<model::Node> CreateNode(
206         IteratorContext* ctx, model::Node::Args args) const override {
207       return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
208     }
209 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)210     Status SaveInternal(SerializationContext* ctx,
211                         IteratorStateWriter* writer) override {
212       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
213           dataset()->captured_func_->CheckExternalState()));
214       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
215       return OkStatus();
216     }
217 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)218     Status RestoreInternal(IteratorContext* ctx,
219                            IteratorStateReader* reader) override {
220       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
221       return OkStatus();
222     }
223 
224    private:
225     std::unique_ptr<IteratorBase> input_impl_;
226     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
227   };
228 
229   const DatasetBase* const input_;
230   const bool preserve_cardinality_;
231   const std::unique_ptr<CapturedFunction> captured_func_;
232   const DataTypeVector output_types_;
233   const std::vector<PartialTensorShape> output_shapes_;
234   // This is used for random access provided by Get().
235   mutable std::unique_ptr<InstantiatedCapturedFunction>
236       instantiated_captured_func_;
237 };
238 
MapDatasetOp(OpKernelConstruction * ctx)239 MapDatasetOp::MapDatasetOp(OpKernelConstruction* ctx)
240     : UnaryDatasetOpKernel(ctx) {
241   FunctionMetadata::Params params;
242   OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseInterOpParallelism,
243                                    &params.use_inter_op_parallelism));
244   OP_REQUIRES_OK(ctx,
245                  FunctionMetadata::Create(ctx, kFunc, params, &func_metadata_));
246   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
247   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
248   OP_REQUIRES_OK(ctx,
249                  ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
250 }
251 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)252 void MapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
253                                DatasetBase** output) {
254   std::unique_ptr<CapturedFunction> captured_func;
255   OP_REQUIRES_OK(ctx,
256                  CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
257                                           &captured_func));
258 
259   *output = new Dataset(ctx, input, std::move(captured_func), output_types_,
260                         output_shapes_, preserve_cardinality_);
261 }
262 
namespace {

// CPU kernel for the MapDataset op.
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
// GPU registration of the ExperimentalMapDataset op, backed by the same
// kernel class. The dataset variant tensors (`input_dataset` input and
// `handle` output) are pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("input_dataset")
                            .HostMemory("handle"),
                        MapDatasetOp);
// Registers "MapDataset" with the input-colocation exemption registry.
// NOTE(review): presumably this lets MapDataset's inputs be placed
// independently of the op — confirm against the registry's documentation.
REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset");

}  // namespace
274 }  // namespace data
275 }  // namespace tensorflow
276