xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/iterator_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/iterator_ops.h"
16 
17 #include <cstdint>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/time/time.h"
25 #include "tensorflow/core/activity_watcher/activity.h"
26 #include "tensorflow/core/activity_watcher/activity_utils.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/graph_runner.h"
29 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/threadpool_device.h"
32 #include "tensorflow/core/data/captured_function.h"
33 #include "tensorflow/core/data/dataset_utils.h"
34 #include "tensorflow/core/data/finalization_utils.h"
35 #include "tensorflow/core/data/metric_utils.h"
36 #include "tensorflow/core/data/root_dataset.h"
37 #include "tensorflow/core/data/serialization_utils.h"
38 #include "tensorflow/core/data/utils.h"
39 #include "tensorflow/core/framework/cancellation.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/metrics.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/partial_tensor_shape.h"
44 #include "tensorflow/core/framework/resource_op_kernel.h"
45 #include "tensorflow/core/framework/stats_aggregator.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/types.h"
48 #include "tensorflow/core/framework/variant_op_registry.h"
49 #include "tensorflow/core/framework/variant_tensor_data.h"
50 #include "tensorflow/core/kernels/data/optional_ops.h"
51 #include "tensorflow/core/kernels/ops_util.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/core/refcount.h"
54 #include "tensorflow/core/lib/core/threadpool.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/lib/random/random.h"
57 #include "tensorflow/core/lib/strings/strcat.h"
58 #include "tensorflow/core/lib/strings/stringprintf.h"
59 #include "tensorflow/core/platform/casts.h"
60 #include "tensorflow/core/platform/env.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/mem.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/refcount.h"
65 #include "tensorflow/core/platform/resource.h"
66 #include "tensorflow/core/profiler/lib/traceme.h"
67 #include "tensorflow/core/profiler/lib/traceme_encode.h"
68 #include "tensorflow/core/public/session_options.h"
69 
70 namespace tensorflow {
71 namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

// Names associated with the anonymous iterator handle ops;
// kAnonymousIterator is what AnonymousIteratorHandleOp::name() returns.
constexpr char kAnonymousIterator[] = "AnonymousIterator";
constexpr char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
constexpr char kAnonymousIteratorV3[] = "AnonymousIteratorV3";

// Type name carried by VariantTensorData holding serialized iterator state;
// IteratorStateVariant::TypeName() returns it and Decode() checks it.
constexpr char kIteratorVariantTypeName[] = "tensorflow::Iterator";

// Attribute names describing the element signature of the iterator ops.
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";

}  // namespace
85 
// Out-of-line definition of the static constexpr member declared in the
// SerializeIteratorOp class definition (not visible here). It is required
// before C++17 whenever the member is ODR-used. Since C++17, static constexpr
// data members are implicitly inline and this redeclaration is redundant but
// still valid.
/* static */ constexpr const char* const
    SerializeIteratorOp::kExternalStatePolicy;
88 
// Creates an iterator resource whose initial State holds no iterator.
// GetNext(), Save() and Restore() fail with FailedPrecondition until
// SetIteratorFromDataset() installs one.
//
// `output_dtypes` / `output_shapes` are the element signature that every
// iterator installed later is verified against. `device_mgr` may be null
// (AnonymousIteratorHandleOp::CreateResource passes an empty pointer).
// `flr` must be non-null: its device type is read to set up the metrics
// collector.
IteratorResource::IteratorResource(
    Env* env, const DataTypeVector& output_dtypes,
    const std::vector<PartialTensorShape>& output_shapes,
    std::unique_ptr<DeviceMgr> device_mgr,
    std::unique_ptr<FunctionLibraryDefinition> flib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
    FunctionLibraryRuntime* flr)
    : metrics_collector_(flr->device()->device_type(), *env),
      unbounded_thread_pool_(env, "tf_data_iterator_resource"),
      device_mgr_(std::move(device_mgr)),
      // The function library objects live in the (swappable) State so that
      // each new State created by Restore()/SetIteratorFromDataset() can
      // share them.
      iterator_state_(std::make_shared<State>(std::move(flib_def),
                                              std::move(pflr), flr,
                                              /*iterator=*/nullptr)),
      output_dtypes_(output_dtypes),
      output_shapes_(output_shapes) {
  VLOG(2) << "creating iterator resource";
}
106 
// Only logs. The owned members (device manager, shared State, thread pool)
// are released by their own destructors.
IteratorResource::~IteratorResource() {
  VLOG(2) << "destroying iterator resource";
}
110 
GetNext(OpKernelContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)111 Status IteratorResource::GetNext(OpKernelContext* ctx,
112                                  std::vector<Tensor>* out_tensors,
113                                  bool* end_of_sequence) {
114   std::shared_ptr<State> captured_state;
115   {
116     tf_shared_lock l(mu_);
117     captured_state = iterator_state_;
118   }
119   if (!captured_state->iterator()) {
120     return errors::FailedPrecondition(
121         "GetNext() failed because the iterator has not been initialized. "
122         "Ensure that you have run the initializer operation for this iterator "
123         "before getting the next element.");
124   }
125   IteratorContext::Params params(ctx);
126   params.flr = captured_state->flr();
127   params.function_handle_cache = captured_state->function_handle_cache();
128   params.resource_mgr = captured_state->resource_mgr();
129   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
130   params.thread_pool = &unbounded_thread_pool_;
131   params.cancellation_manager = captured_state->cancellation_manager();
132   std::function<void()> deregister_fn;
133   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
134       ctx->cancellation_manager(),
135       [cm = params.cancellation_manager]() { cm->StartCancel(); },
136       &deregister_fn));
137   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
138 
139   const absl::Time start_time = metrics_collector_.RecordStart();
140   auto iterator_ = captured_state->iterator();
141   auto status = iterator_->GetNext(IteratorContext(std::move(params)),
142                                    out_tensors, end_of_sequence);
143   metrics_collector_.RecordStop(start_time, *out_tensors);
144   return status;
145 }
146 
Save(SerializationContext * ctx,IteratorStateWriter * writer)147 Status IteratorResource::Save(SerializationContext* ctx,
148                               IteratorStateWriter* writer) {
149   std::shared_ptr<State> captured_state;
150   {
151     tf_shared_lock l(mu_);
152     captured_state = iterator_state_;
153   }
154   auto iterator_ = captured_state->iterator();
155   if (iterator_) {
156     return iterator_->Save(ctx, writer);
157   }
158   return errors::FailedPrecondition(
159       "Save() failed because the iterator has not been initialized. Ensure "
160       "that you have run the initializer operation for this iterator before "
161       "saving it.");
162 }
163 
Restore(OpKernelContext * ctx,IteratorStateReader * reader)164 Status IteratorResource::Restore(OpKernelContext* ctx,
165                                  IteratorStateReader* reader) {
166   const DatasetBase* dataset;
167   std::shared_ptr<State> new_state;
168   const DatasetBase* input_dataset;
169   {
170     tf_shared_lock l(mu_);
171     if (!iterator_state_->iterator()) {
172       return errors::FailedPrecondition(
173           "Restore() failed because the iterator has not been initialized. "
174           "Ensure that you have run the initializer operation for this "
175           "iterator before restoring it.");
176     }
177     auto iterator_ = iterator_state_->iterator();
178     dataset = iterator_->dataset();
179     // Hang onto a reference until we've created the new iterator, which will
180     // then hold its own reference to keep the dataset alive.
181     dataset->Ref();
182     new_state =
183         std::make_shared<State>(iterator_state_->flib_def(),
184                                 iterator_state_->pflr(), iterator_state_->flr(),
185                                 /*iterator=*/nullptr);
186     input_dataset = iterator_state_->dataset();
187   }
188   core::ScopedUnref scoped_unref(dataset);
189   IteratorContext::Params params(ctx);
190   params.flr = new_state->flr();
191   params.function_handle_cache = new_state->function_handle_cache();
192   params.resource_mgr = new_state->resource_mgr();
193   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
194   params.thread_pool = &unbounded_thread_pool_;
195   params.cancellation_manager = new_state->cancellation_manager();
196   std::function<void()> deregister_fn;
197   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
198       ctx->cancellation_manager(),
199       [cm = params.cancellation_manager]() { cm->StartCancel(); },
200       &deregister_fn));
201   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
202   std::unique_ptr<IteratorBase> iterator_base;
203   TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
204       IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
205   new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
206                                               input_dataset);
207 
208   mutex_lock l(mu_);
209   std::swap(iterator_state_, new_state);
210   return OkStatus();
211 }
212 
SetIteratorFromDataset(OpKernelContext * ctx,const DatasetBase * dataset)213 Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
214                                                 const DatasetBase* dataset) {
215   std::shared_ptr<State> new_state;
216   {
217     tf_shared_lock l(mu_);
218     new_state =
219         std::make_shared<State>(iterator_state_->flib_def(),
220                                 iterator_state_->pflr(), iterator_state_->flr(),
221                                 /*iterator=*/nullptr);
222   }
223 
224   // Create new iterator.
225   IteratorContext::Params params(ctx);
226   params.flr = new_state->flr();
227   params.function_handle_cache = new_state->function_handle_cache();
228   params.resource_mgr = new_state->resource_mgr();
229   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
230   params.thread_pool = &unbounded_thread_pool_;
231   params.cancellation_manager = new_state->cancellation_manager();
232   std::function<void()> deregister_fn;
233   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
234       ctx->cancellation_manager(),
235       [cm = params.cancellation_manager]() { cm->StartCancel(); },
236       &deregister_fn));
237   auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
238 
239   std::unique_ptr<IteratorBase> iterator;
240   if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
241     DatasetBase* finalized_dataset;
242     TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
243     TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
244         IteratorContext(std::move(params)),
245         /*parent=*/nullptr, "Iterator", &iterator));
246   } else {
247     TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
248                                              /*parent=*/nullptr, "Iterator",
249                                              &iterator));
250   }
251   TF_RETURN_IF_ERROR(
252       VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
253   TF_RETURN_IF_ERROR(
254       VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
255 
256   new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
257 
258   mutex_lock l(mu_);
259   std::swap(iterator_state_, new_state);
260   return OkStatus();
261 }
262 
263 namespace {
264 
265 // Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
// The GetData() method returns a pointer to the VariantTensorData object which
// contains all the state needed to restore a single iterator.
268 //
269 // Usage example:
270 //
271 // Encoding:
272 //
273 //   Tensor t(DT_VARIANT, TensorShape({}));
//   t.scalar<Variant>()() = IteratorStateVariant();
275 //
276 // Encode() sets the type_name of the VariantTensorData object to
277 // IteratorStateVariant::TypeName().
278 //
279 // Decoding:
280 //
281 //   Variant v = <VariantTensorDataProto object>;
282 //   DecodeUnaryVariant(&v);
283 //   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
//   VariantTensorDataReader reader({wrapper->GetData()});
285 //   iterator_resource->Restore(ctx, &reader);
286 //
287 // The type_name of the VariantTensorData object to be decoded must
288 // match IteratorStateVariant::TypeName().
289 class IteratorStateVariant {
290  public:
IteratorStateVariant()291   IteratorStateVariant() : data_(nullptr) {}
IteratorStateVariant(const IteratorStateVariant & other)292   IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
293     if (other.data_) {
294       Decode(*other.data_);
295     }
296   }
297   IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
298   IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
299 
300   // Initializes `this` from a VariantTensorData object.
InitializeFromVariantData(std::unique_ptr<VariantTensorData> d)301   Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
302     data_ = std::move(d);
303     return OkStatus();
304   }
305 
TypeName() const306   string TypeName() const { return kIteratorVariantTypeName; }
Encode(VariantTensorData * data) const307   void Encode(VariantTensorData* data) const { *data = *data_; }
Decode(VariantTensorData data)308   bool Decode(VariantTensorData data) {
309     if (data.type_name() != TypeName()) {
310       return false;
311     }
312     auto tensor_data = std::make_unique<VariantTensorData>();
313     std::swap(*tensor_data, data);
314     data_ = std::move(tensor_data);
315     return true;
316   }
317 
318   // Returns a borrowed pointer to the underlying VariantTensorData.
GetData() const319   const VariantTensorData* GetData() const { return data_.get(); }
320 
DebugString() const321   string DebugString() const {
322     if (data_) {
323       return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
324                              ">");
325     } else {
326       return strings::StrCat("IteratorStateVariant<empty>");
327     }
328   }
329 
330  private:
331   std::unique_ptr<VariantTensorData> data_;
332 };
333 
// Register IteratorStateVariant in the global variant decode_fn registry under
// kIteratorVariantTypeName. A Variant containing a serialized representation
// of iterator state can then be decoded with DecodeUnaryVariant. Without this
// registration, DeserializeIteratorOp would have to decode the returned
// Variant manually with MaybeDecodeAndCopy, which is not recommended.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                       kIteratorVariantTypeName);
341 
342 // A helper class that uses a list of IteratorStateVariant objects to represent
343 // the state for an iterator resource. It exposes methods that help with
344 // saving and restoring of this state. Sample usage
345 // Saving:
346 //   IteratorVariantSerializer serializer;
//   serializer.InitializeFromIterator(serialization_ctx, iterator_resource);
348 //   Tensor serialized_t;
349 //   serializer.Serialize(&serialized_t);
350 //
351 // Restoring:
352 //   IteratorVariantSerializer serializer;
//   serializer.InitFromTensor(&ctx->input(0));
354 //   IteratorStateReader* reader = serializer.GetReader();
355 //   iterator_resource->Restore(ctx, reader);
356 class IteratorVariantSerializer {
357  public:
IteratorVariantSerializer()358   IteratorVariantSerializer() {}
359 
360   // Calls `Save` on the iterator_resource to build up the list of
361   // IteratorStateVariant objects.
InitializeFromIterator(SerializationContext * serialization_ctx,IteratorResource * iterator_resource)362   Status InitializeFromIterator(SerializationContext* serialization_ctx,
363                                 IteratorResource* iterator_resource) {
364     VariantTensorDataWriter writer;
365     TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
366     std::vector<std::unique_ptr<VariantTensorData>> data;
367     writer.ReleaseData(&data);
368     variants_.clear();
369     variants_.reserve(data.size());
370     for (auto& it : data) {
371       IteratorStateVariant v;
372       TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
373       variants_.push_back(v);
374     }
375     num_tensors_ = variants_.size();
376     can_serialize_ = true;
377     return OkStatus();
378   }
379 
380   // Initializes `this` from `serialized_t` while restoring the iterator state.
InitFromTensor(const Tensor * serialized_t)381   Status InitFromTensor(const Tensor* serialized_t) {
382     int64_t num_tensors = serialized_t->dim_size(0);
383     auto serialized_vec = serialized_t->vec<Variant>();
384     std::vector<const VariantTensorData*> data;
385     data.reserve(num_tensors);
386     for (int i = 0; i < num_tensors; ++i) {
387       auto* w = serialized_vec(i).get<IteratorStateVariant>();
388       if (!w) {
389         return errors::Internal(
390             "Cannot initialize an iterator from tensor ",
391             serialized_vec(i).DebugString(),
392             ". Expected a variant tensor of type IteratorStateVariant");
393       }
394       data.push_back(w->GetData());
395     }
396     reader_ = std::make_unique<VariantTensorDataReader>(data);
397     num_tensors_ = data.size();
398     return OkStatus();
399   }
400 
NumTensors()401   int64_t NumTensors() { return num_tensors_; }
402 
403   // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
404   // that InitializeFromIterator was called before.
Serialize(Tensor * serialized)405   Status Serialize(Tensor* serialized) {
406     if (!can_serialize_) {
407       return errors::InvalidArgument(
408           "Please call InitializeFromIterator before calling Serialize.");
409     }
410     int64_t size = variants_.size();
411     for (int64_t i = 0; i < size; ++i) {
412       if (variants_[i].GetData() == nullptr) {
413         return errors::Internal(
414             "Cannot serialize an empty IteratorStateVariant");
415       }
416       serialized->vec<Variant>()(i) = variants_[i];
417     }
418     return OkStatus();
419   }
420 
421   // Returns an IteratorStateReader to restore iterator state. Expects that
422   // InitFromTensor was called before.
GetReader()423   IteratorStateReader* GetReader() { return reader_.get(); }
424 
425  private:
426   bool can_serialize_ = false;
427   int64_t num_tensors_;
428   std::vector<IteratorStateVariant> variants_;
429   std::unique_ptr<IteratorStateReader> reader_;
430 };
431 
432 }  // namespace
433 
// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
//
// Reads the expected element signature (`output_types`, `output_shapes`) and
// `shared_name`. A non-empty shared name makes Compute() build a private
// function library runtime for the resource.
IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
}
443 
444 // The resource is deleted from the resource manager only when it is private
445 // to kernel. Ideally the resource should be deleted when it is no longer held
446 // by anyone, but it would break backward compatibility.
~IteratorHandleOp()447 IteratorHandleOp::~IteratorHandleOp() {
448   if (resource_ != nullptr) {
449     resource_->Unref();
450     if (cinfo_.resource_is_private_to_kernel()) {
451       if (!cinfo_.resource_manager()
452                ->template Delete<IteratorResource>(cinfo_.container(),
453                                                    cinfo_.name())
454                .ok()) {
455         // Do nothing; the resource can have been deleted by session resets.
456       }
457     }
458   }
459 }
460 
// On the first successful run, looks up (or creates) the IteratorResource
// named by `cinfo_` and caches a reference to it in `resource_`. Every run
// then outputs a handle to that resource.
void IteratorHandleOp::Compute(OpKernelContext* context)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    if (resource_ == nullptr) {
      FunctionLibraryRuntime* flr;
      std::unique_ptr<DeviceMgr> device_mgr(nullptr);
      std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
      // If the iterator is shared then we construct a new FLR, and pass that
      // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
      // functions from the iterator. We may add this functionality if there
      // is sufficient demand, but it will require a significant refactoring.
      if (!name_.empty()) {
        flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
      } else {
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr, true));
      }

      ResourceMgr* mgr = context->resource_manager();
      OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

      IteratorResource* resource;
      // The creator lambda takes ownership of the FLR objects built above.
      // Presumably it only runs when no resource with this name exists yet;
      // otherwise those objects are simply released at the end of this scope.
      OP_REQUIRES_OK(
          context,
          mgr->LookupOrCreate<IteratorResource>(
              cinfo_.container(), cinfo_.name(), &resource,
              [context, flr, &device_mgr, &flib_def, &pflr,
               this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                *ret = new IteratorResource(
                    context->env(), output_dtypes_, output_shapes_,
                    std::move(device_mgr), std::move(flib_def), std::move(pflr),
                    flr);
                return OkStatus();
              }));

      // A pre-existing resource with the same name may have a different
      // element signature. Reject it and drop the reference obtained above.
      Status s = VerifyResource(resource);
      if (TF_PREDICT_FALSE(!s.ok())) {
        resource->Unref();
        context->SetStatus(s);
        return;
      }

      // Keep the reference; it is released in ~IteratorHandleOp().
      resource_ = resource;
    }
  }
  OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                              context, 0, cinfo_.container(), cinfo_.name(),
                              TypeIndex::Make<IteratorResource>()));
}
512 
// Checks that `resource` has the element dtypes and (compatible) shapes this
// kernel was constructed with. Returns the first mismatch error, else OK.
Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
  TF_RETURN_IF_ERROR(
      VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
  TF_RETURN_IF_ERROR(
      VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
  return OkStatus();
}
520 
CreatePrivateFLR(OpKernelContext * ctx,std::unique_ptr<DeviceMgr> * device_mgr,std::unique_ptr<FunctionLibraryDefinition> * flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * pflr)521 FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
522     OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
523     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
524     std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
525   // Wrap the existing device in order to see any captured resources
526   // in its resource manager. The existing device will outlive the
527   // IteratorResource, because we are storing the IteratorResource
528   // in that device's resource manager.
529 
530   *device_mgr =
531       std::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
532           ctx->device()->name(), down_cast<Device*>(ctx->device()),
533           false /* owns_underlying */, false /* isolate_session_state */));
534   *flib_def = std::make_unique<FunctionLibraryDefinition>(
535       *ctx->function_library()->GetFunctionLibraryDefinition());
536   const auto* config = ctx->function_library()->config_proto();
537   *pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
538       device_mgr->get(), ctx->env(),
539       /*config=*/config, graph_def_version_, flib_def->get(),
540       config->graph_options().optimizer_options());
541 
542   return (*pflr)->GetFLR(ctx->device()->name());
543 }
544 
// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
    OpKernelConstruction* context)
    : AnonymousResourceOp<IteratorResource>(
          context,
          /* ref_counting */
          // Always disabled for the TfLite environment: the ResourceMgr
          // releases the resource at the end of execution, because the TfLite
          // kernel holds this resource tensor until the end.
          false,
          /* return_deleter */
          // Always disabled for the TfLite environment for the same reason:
          // the ResourceMgr releases the resource at the end of execution.
          false),
      graph_def_version_(context->graph_def_version()) {
  // Element signature forwarded to every IteratorResource built by
  // CreateResource().
  OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
  OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
}
567 
name()568 string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }
569 
CreateResource(OpKernelContext * ctx,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * lib,IteratorResource ** resource)570 Status AnonymousIteratorHandleOp::CreateResource(
571     OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
572     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
573     FunctionLibraryRuntime* lib, IteratorResource** resource) {
574   std::unique_ptr<DeviceMgr> device_mgr(nullptr);
575   *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
576                                    std::move(device_mgr), std::move(flib_def),
577                                    std::move(pflr), lib);
578   return OkStatus();
579 }
580 
// Creates the kernel together with a background worker named
// `background_worker_name`, on which ComputeAsync() schedules DoCompute().
HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
                                         const char* background_worker_name)
    : AsyncOpKernel(ctx),
      background_worker_(ctx->env(), background_worker_name) {}
585 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)586 void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
587                                        DoneCallback done) {
588   background_worker_.Schedule([this, ctx, done = std::move(done)]() {
589     ctx->SetStatus(DoCompute(ctx));
590     done();
591   });
592 }
593 
Compute(OpKernelContext * ctx)594 void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
595   ctx->SetStatus(DoCompute(ctx));
596 }
597 
DoCompute(OpKernelContext * ctx)598 Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
599   tensorflow::ResourceTagger tag(kTFDataResourceTag,
600                                  ctx->op_kernel().type_string());
601   DatasetBase* dataset;
602   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
603   IteratorResource* iterator_resource;
604   TF_RETURN_IF_ERROR(
605       LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
606   core::ScopedUnref unref_iterator(iterator_resource);
607   return iterator_resource->SetIteratorFromDataset(ctx, dataset);
608 }
609 
DoCompute(OpKernelContext * ctx)610 Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
611   tensorflow::ResourceTagger tag(kTFDataResourceTag,
612                                  ctx->op_kernel().type_string());
613   const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
614   // The iterator resource is guaranteed to exist because the variant tensor
615   // wrapping the deleter is provided as an unused input to this op, which
616   // guarantees that it has not run yet.
617   return DeleteResource(ctx, handle);
618 }
619 
620 namespace {
621 
622 class ToSingleElementOp : public AsyncOpKernel {
623  public:
ToSingleElementOp(OpKernelConstruction * ctx)624   explicit ToSingleElementOp(OpKernelConstruction* ctx)
625       : AsyncOpKernel(ctx),
626         metrics_collector_(ctx->device()->attributes().device_type(),
627                            *ctx->env()),
628         unbounded_threadpool_(ctx->env(), "tf_data_to_single_element") {
629     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
630     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
631   }
632 
ComputeAsync(OpKernelContext * ctx,DoneCallback done)633   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
634     unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
635       ctx->SetStatus(DoCompute(ctx));
636       done();
637     });
638   }
639 
Compute(OpKernelContext * ctx)640   void Compute(OpKernelContext* ctx) override {
641     ctx->SetStatus(DoCompute(ctx));
642   }
643 
644  private:
DoCompute(OpKernelContext * ctx)645   Status DoCompute(OpKernelContext* ctx) {
646     profiler::TraceMe traceme(
647         [&] {
648           return profiler::TraceMeEncode("ToSingleElementOp::DoCompute",
649                                          {{"id", ctx->step_id()}});
650         },
651         profiler::kInfo);
652     tensorflow::ResourceTagger tag(kTFDataResourceTag,
653                                    ctx->op_kernel().type_string());
654     DatasetBase* dataset;
655     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
656 
657     IteratorContext::Params params(ctx);
658     ResourceMgr resource_mgr;
659     params.resource_mgr = &resource_mgr;
660     CancellationManager cancellation_manager(ctx->cancellation_manager());
661     params.cancellation_manager = &cancellation_manager;
662 
663     IteratorContext iter_ctx(std::move(params));
664     std::unique_ptr<IteratorBase> iterator;
665     TF_RETURN_IF_ERROR(dataset->MakeIterator(
666         &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
667 
668     std::vector<Tensor> components;
669     components.reserve(dataset->output_dtypes().size());
670     bool end_of_sequence = false;
671 
672     const absl::Time start_time = metrics_collector_.RecordStart();
673     TF_RETURN_IF_ERROR(
674         iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
675     metrics_collector_.RecordStop(start_time, components);
676 
677     if (end_of_sequence) {
678       return errors::InvalidArgument("Dataset was empty.");
679     }
680     TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
681     TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
682     for (int i = 0; i < components.size(); ++i) {
683       ctx->set_output(i, components[i]);
684     }
685 
686     components.clear();
687     TF_RETURN_IF_ERROR(
688         iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
689     if (!end_of_sequence) {
690       return errors::InvalidArgument("Dataset had more than one element.");
691     }
692     return OkStatus();
693   }
694 
695   IteratorMetricsCollector metrics_collector_;
696   UnboundedThreadPool unbounded_threadpool_;
697   DataTypeVector output_types_;
698   std::vector<PartialTensorShape> output_shapes_;
699 };
700 
701 class OneShotIteratorOp : public AsyncOpKernel {
702  public:
OneShotIteratorOp(OpKernelConstruction * ctx)703   explicit OneShotIteratorOp(OpKernelConstruction* ctx)
704       : AsyncOpKernel(ctx),
705         background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
706         graph_def_version_(ctx->graph_def_version())
707 
708   {
709     string shared_name;
710     OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
711     OP_REQUIRES(ctx, shared_name.empty(),
712                 errors::InvalidArgument("OneShotIteratorOp does not currently "
713                                         "support the 'shared_name' attr."));
714     OP_REQUIRES_OK(ctx,
715                    ctx->GetAttr("dataset_factory", &dataset_factory_func_));
716     OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
717     OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
718   }
719 
~OneShotIteratorOp()720   ~OneShotIteratorOp() override {
721     if (iterator_resource_ != nullptr) {
722       iterator_resource_->Unref();
723       if (!cinfo_.resource_manager()
724                ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
725                .ok()) {
726         // Do nothing; the resource can have been deleted by session resets.
727       }
728     }
729   }
730 
731   // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`,
732   // but due to the fact that `ResourceOpKernel<T>::CreateResource()`
733   // does not provide access to the `OpKernelContext*` and we need
734   // this to invoke the factory function, it's not possible to
735   // implement this kernel by implementing `CreateResource()`.
736   // Furthermore, due to the fact that this kernel might block when
737   // running the initialization function, we must implement this
738   // kernel as an async kernel.
ComputeAsync(OpKernelContext * ctx,DoneCallback done)739   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
740     tensorflow::ResourceTagger tag(kTFDataResourceTag,
741                                    ctx->op_kernel().type_string());
742     {
743       mutex_lock l(mu_);
744       if (iterator_resource_ == nullptr && initialization_status_.ok()) {
745         // The initialization thread will call `done`.
746         if (!initialization_started_) {
747           // TODO(mrry): Convert the initialization code to use
748           // callbacks instead of wasting a thread.
749           background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
750           initialization_started_ = true;
751         } else {
752           done_callbacks_.emplace_back(ctx, std::move(done));
753         }
754         return;
755       }
756     }
757     ProduceOutput(ctx, done);
758   }
759 
760  private:
Init(OpKernelContext * ctx,const DoneCallback & done)761   void Init(OpKernelContext* ctx, const DoneCallback& done) {
762     IteratorResource* iterator = nullptr;
763     ContainerInfo cinfo;
764     Status s = TryInit(ctx, &iterator, &cinfo);
765 
766     std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
767     {
768       mutex_lock l(mu_);
769       if (s.ok()) {
770         iterator_resource_ = iterator;
771         cinfo_ = cinfo;
772       }
773       initialization_status_ = s;
774       std::swap(done_callbacks_, callbacks_to_run);
775     }
776 
777     for (auto&& ctx_done : callbacks_to_run) {
778       ProduceOutput(ctx_done.first, ctx_done.second);
779     }
780     ProduceOutput(ctx, done);
781   }
782 
TryInit(OpKernelContext * ctx,IteratorResource ** iterator,ContainerInfo * cinfo)783   Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
784                  ContainerInfo* cinfo) {
785     TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));
786 
787     FunctionLibraryRuntime* flr;
788     std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
789     std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
790     TF_RETURN_IF_ERROR(
791         ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
792 
793     // Create an IteratorResource that will hold the iterator for this op.
794     TF_RETURN_IF_ERROR(
795         ctx->resource_manager()->LookupOrCreate<IteratorResource>(
796             cinfo->container(), cinfo->name(), iterator,
797             [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
798                 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
799                   *ret = new IteratorResource(
800                       ctx->env(), output_dtypes_, output_shapes_,
801                       /*device_mgr=*/nullptr, std::move(flib_def),
802                       std::move(pflr), flr);
803                   return OkStatus();
804                 }));
805 
806     core::ScopedUnref unref_iterator(*iterator);
807 
808     TF_RETURN_IF_ERROR(
809         VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
810     TF_RETURN_IF_ERROR(
811         VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));
812 
813     // Call the dataset_factory_func_ to create a new dataset,
814     // over which this op will iterate.
815     FunctionLibraryRuntime::Handle f_handle;
816     TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
817         dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
818         &f_handle));
819     FunctionLibraryRuntime::Options opts;
820     opts.cancellation_manager = ctx->cancellation_manager();
821     ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
822       ctx->resource_manager()->Cleanup(name).IgnoreError();
823     });
824     opts.step_container = &step_container;
825     opts.runner = ctx->runner();
826     opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
827     std::vector<Tensor> return_values;
828     TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
829         std::move(opts), f_handle, {}, &return_values));
830     if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
831         !TensorShapeUtils::IsScalar(return_values[0].shape())) {
832       return errors::InvalidArgument(
833           "The `dataset_factory` function must return "
834           "a single scalar of dtype DT_VARIANT.");
835     }
836 
837     // Create an iterator for the dataset that was created in the
838     // factory function.
839     DatasetBase* dataset;
840     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
841     TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
842     (*iterator)->Ref();
843     return OkStatus();
844   }
845 
ProduceOutput(OpKernelContext * ctx,const DoneCallback & done)846   void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
847     Tensor* handle;
848     OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
849                          done);
850     Status s;
851     {
852       mutex_lock l(mu_);
853       s = initialization_status_;
854       if (s.ok()) {
855         handle->scalar<ResourceHandle>()() =
856             MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
857                                                  cinfo_.name());
858       }
859     }
860     OP_REQUIRES_OK_ASYNC(ctx, s, done);
861     done();
862   }
863 
864   NameAttrList dataset_factory_func_;
865   DataTypeVector output_dtypes_;
866   std::vector<PartialTensorShape> output_shapes_;
867 
868   BackgroundWorker background_worker_;
869 
870   mutex mu_;
871   ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
872   IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;
873 
874   bool initialization_started_ TF_GUARDED_BY(mu_) = false;
875   Status initialization_status_ TF_GUARDED_BY(mu_);
876   std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
877       TF_GUARDED_BY(mu_);
878   const int graph_def_version_;
879 };
880 
881 }  // namespace
882 
AsAsync()883 AsyncOpKernel* IteratorGetNextOp::AsAsync() {
884   return type_string() == "IteratorGetNextSync" ? nullptr : this;
885 }
886 
RecordElementSize(const std::vector<Tensor> element,profiler::TraceMe * traceme)887 void RecordElementSize(const std::vector<Tensor> element,
888                        profiler::TraceMe* traceme) {
889   traceme->AppendMetadata([&]() {
890     int64_t element_size = 0;
891     for (const auto& component : element) {
892       element_size += component.TotalBytes();
893     }
894     return profiler::TraceMeEncode({{"element_size", element_size}});
895   });
896 }
897 
DoCompute(OpKernelContext * ctx)898 Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
899   VLOG(3) << "IteratorGetNextOp enter. iter_id=" << ctx->frame_iter().iter_id;
900   auto cleanup = gtl::MakeCleanup([ctx] {
901     VLOG(3) << "IteratorGetNextOp exit. iter_id=" << ctx->frame_iter().iter_id;
902   });
903   activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
904     return activity_watcher::ActivityFromContext(
905         ctx, "IteratorGetNextOp::DoCompute",
906         activity_watcher::ActivityCategory::kDatasetOp);
907   });
908   profiler::TraceMe traceme(
909       [&] {
910         return profiler::TraceMeEncode(
911             "IteratorGetNextOp::DoCompute",
912             {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
913       },
914       profiler::kInfo);
915   tensorflow::ResourceTagger tag(kTFDataResourceTag,
916                                  ctx->op_kernel().type_string());
917   IteratorResource* iterator;
918   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
919   core::ScopedUnref unref_iterator(iterator);
920   std::vector<Tensor> components;
921   bool end_of_sequence = false;
922 
923   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
924   if (end_of_sequence) {
925     return errors::OutOfRange("End of sequence");
926   }
927   TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
928   TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
929   RecordElementSize(components, &traceme);
930   for (int i = 0; i < components.size(); ++i) {
931     ctx->set_output(i, components[i]);
932   }
933   return OkStatus();
934 }
935 
DoCompute(OpKernelContext * ctx)936 Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
937   VLOG(3) << "IteratorGetNextAsOptionalOp enter. iter_id="
938           << ctx->frame_iter().iter_id;
939   auto cleanup = gtl::MakeCleanup([ctx] {
940     VLOG(3) << "IteratorGetNextAsOptionalOp exit. iter_id="
941             << ctx->frame_iter().iter_id;
942   });
943   activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
944     return activity_watcher::ActivityFromContext(
945         ctx, "IteratorGetNextAsOptionalOp::DoCompute",
946         activity_watcher::ActivityCategory::kDatasetOp);
947   });
948   profiler::TraceMe traceme(
949       [&] {
950         return profiler::TraceMeEncode(
951             "IteratorGetNextAsOptionalOp::DoCompute",
952             {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
953       },
954       profiler::kInfo);
955   tensorflow::ResourceTagger tag(kTFDataResourceTag,
956                                  ctx->op_kernel().type_string());
957   IteratorResource* iterator;
958   TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
959   core::ScopedUnref unref_iterator(iterator);
960   std::vector<Tensor> components;
961   bool end_of_sequence = false;
962 
963   TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
964 
965   if (end_of_sequence) {
966     return WriteOptionalNoneToOutput(ctx, 0);
967   } else {
968     RecordElementSize(components, &traceme);
969     for (int i = 0; i < components.size(); ++i) {
970       if (components[i].dtype() != output_types_[i]) {
971         return errors::InvalidArgument(
972             "The given optional does not match the expected type for "
973             "component ",
974             i, ". Expected: ", DataTypeString(output_types_[i]),
975             ". Actual: ", DataTypeString(components[i].dtype()), ".");
976       }
977       if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
978         return errors::InvalidArgument(
979             "The given optional does not match the expected shape "
980             "for component ",
981             i, ". Expected: ", output_shapes_[i].DebugString(),
982             ". Actual: ", components[i].shape().DebugString(), ".");
983       }
984     }
985     return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
986   }
987 }
988 
Compute(OpKernelContext * ctx)989 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
990   const Tensor& resource_handle_t = ctx->input(0);
991   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
992               errors::InvalidArgument("resource_handle must be a scalar"));
993 
994   // Validate that the handle corresponds to a real resource, and
995   // that it is an IteratorResource.
996   IteratorResource* iterator_resource;
997   OP_REQUIRES_OK(
998       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
999   iterator_resource->Unref();
1000 
1001   Tensor* string_handle_t;
1002   OP_REQUIRES_OK(ctx,
1003                  ctx->allocate_output(0, TensorShape({}), &string_handle_t));
1004   string_handle_t->scalar<tstring>()() =
1005       resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
1006 }
1007 
IteratorFromStringHandleOp(OpKernelConstruction * ctx)1008 IteratorFromStringHandleOp::IteratorFromStringHandleOp(
1009     OpKernelConstruction* ctx)
1010     : OpKernel(ctx) {
1011   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
1012   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1013   OP_REQUIRES(
1014       ctx,
1015       output_dtypes_.empty() || output_shapes_.empty() ||
1016           output_dtypes_.size() == output_shapes_.size(),
1017       errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
1018                               "are set, they must have the same length."));
1019 }
1020 
Compute(OpKernelContext * ctx)1021 void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
1022   const Tensor& string_handle_t = ctx->input(0);
1023   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
1024               errors::InvalidArgument("string_handle must be a scalar"));
1025 
1026   ResourceHandle resource_handle;
1027   OP_REQUIRES(
1028       ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
1029       errors::InvalidArgument(
1030           "Could not parse string_handle as a valid ResourceHandle"));
1031 
1032   OP_REQUIRES(
1033       ctx, resource_handle.device() == ctx->device()->attributes().name(),
1034       errors::InvalidArgument("Attempted create an iterator on device \"",
1035                               ctx->device()->attributes().name(),
1036                               "\" from handle defined on device \"",
1037                               resource_handle.device(), "\""));
1038 
1039   // Validate that the handle corresponds to a real resource, and
1040   // that it is an IteratorResource.
1041   IteratorResource* iterator_resource;
1042   OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
1043   core::ScopedUnref unref_iterator(iterator_resource);
1044   if (!output_dtypes_.empty()) {
1045     OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
1046                                          iterator_resource->output_dtypes()));
1047   }
1048   if (!output_shapes_.empty()) {
1049     OP_REQUIRES_OK(ctx,
1050                    VerifyShapesCompatible(output_shapes_,
1051                                           iterator_resource->output_shapes()));
1052   }
1053 
1054   Tensor* resource_handle_t;
1055   OP_REQUIRES_OK(ctx,
1056                  ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
1057   resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
1058 }
1059 
SerializeIteratorOp(OpKernelConstruction * ctx)1060 SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
1061     : OpKernel(ctx) {
1062   if (ctx->HasAttr(kExternalStatePolicy)) {
1063     int64_t state_change_option;
1064     OP_REQUIRES_OK(ctx,
1065                    ctx->GetAttr(kExternalStatePolicy, &state_change_option));
1066     external_state_policy_ =
1067         SerializationContext::ExternalStatePolicy(state_change_option);
1068   }
1069 }
1070 
Compute(OpKernelContext * ctx)1071 void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
1072   tensorflow::ResourceTagger tag(kTFDataResourceTag,
1073                                  ctx->op_kernel().type_string());
1074   const Tensor& resource_handle_t = ctx->input(0);
1075   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
1076               errors::InvalidArgument("resource_handle must be a scalar"));
1077   // Validate that the handle corresponds to a real resource, and
1078   // that it is an IteratorResource.
1079   IteratorResource* iterator_resource;
1080   OP_REQUIRES_OK(
1081       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1082   core::ScopedUnref unref_iterator(iterator_resource);
1083   IteratorVariantSerializer serializer;
1084   SerializationContext::Params params(ctx);
1085   params.external_state_policy = external_state_policy_;
1086   SerializationContext serialization_ctx(params);
1087   OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
1088                                                         iterator_resource));
1089   Tensor* serialized_t;
1090   OP_REQUIRES_OK(ctx,
1091                  ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
1092                                       &serialized_t));
1093   OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
1094 }
1095 
Compute(OpKernelContext * ctx)1096 void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
1097   tensorflow::ResourceTagger tag(kTFDataResourceTag,
1098                                  ctx->op_kernel().type_string());
1099   // Validate that the handle corresponds to a real resource, and
1100   // that it is an IteratorResource.
1101   IteratorResource* iterator_resource;
1102   OP_REQUIRES_OK(
1103       ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1104   core::ScopedUnref unref_iterator(iterator_resource);
1105   const Tensor* serialized_t;
1106   OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
1107   IteratorVariantSerializer serializer;
1108   OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
1109   Status s = iterator_resource->Restore(ctx, serializer.GetReader());
1110   if (!s.ok()) {
1111     OP_REQUIRES_OK(
1112         ctx,
1113         errors::CreateWithUpdatedMessage(
1114             s, absl::StrCat(
1115                    "Failed to restore dataset iterator from checkpoint: ",
1116                    s.error_message(),
1117                    ". Make sure the dataset definition has not changed between "
1118                    "the process that saved the checkpoint and the process that "
1119                    "is restoring it.")));
1120   }
1121 }
1122 
1123 namespace {
1124 
1125 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
1126 REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
1127                         IteratorHandleOp);
1128 REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
1129                         IteratorHandleOp);
1130 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
1131                         MakeIteratorOp);
1132 REGISTER_KERNEL_BUILDER(
1133     Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
1134     MakeIteratorOp);
1135 REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
1136                         DeleteIteratorOp);
1137 REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
1138                         DeleteIteratorOp);
1139 REGISTER_KERNEL_BUILDER(
1140     Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
1141     AnonymousIteratorHandleOp);
1142 REGISTER_KERNEL_BUILDER(
1143     Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
1144     AnonymousIteratorHandleOp);
1145 REGISTER_KERNEL_BUILDER(
1146     Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
1147     AnonymousIteratorHandleOp);
1148 REGISTER_KERNEL_BUILDER(
1149     Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
1150     AnonymousIteratorHandleOp);
1151 REGISTER_KERNEL_BUILDER(
1152     Name("AnonymousIteratorV3").Device(DEVICE_CPU).Priority(2),
1153     AnonymousIteratorHandleOp);
1154 REGISTER_KERNEL_BUILDER(
1155     Name("AnonymousIteratorV3").Device(DEVICE_GPU).Priority(1),
1156     AnonymousIteratorHandleOp);
1157 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
1158                         ToSingleElementOp);
1159 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
1160                         OneShotIteratorOp);
1161 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
1162                         IteratorGetNextOp);
1163 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
1164                         IteratorGetNextOp);
1165 REGISTER_KERNEL_BUILDER(
1166     Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
1167     IteratorGetNextOp);
1168 REGISTER_KERNEL_BUILDER(
1169     Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
1170     IteratorGetNextOp);
1171 REGISTER_KERNEL_BUILDER(
1172     Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
1173     IteratorGetNextAsOptionalOp);
1174 REGISTER_KERNEL_BUILDER(
1175     Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
1176     IteratorGetNextAsOptionalOp);
1177 REGISTER_KERNEL_BUILDER(
1178     Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
1179     IteratorToStringHandleOp);
1180 REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
1181                             .Device(DEVICE_GPU)
1182                             .HostMemory("string_handle")
1183                             .Priority(1),
1184                         IteratorToStringHandleOp);
1185 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
1186                         IteratorFromStringHandleOp);
1187 REGISTER_KERNEL_BUILDER(
1188     Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
1189     IteratorFromStringHandleOp);
1190 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
1191                             .Device(DEVICE_GPU)
1192                             .HostMemory("string_handle")
1193                             .Priority(1),
1194                         IteratorFromStringHandleOp);
1195 REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
1196                         SerializeIteratorOp);
1197 REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
1198                         DeserializeIteratorOp);
1199 
1200 }  // namespace
1201 
1202 }  // namespace data
1203 }  // namespace tensorflow
1204