1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/iterator_ops.h"
16
17 #include <cstdint>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "absl/memory/memory.h"
24 #include "absl/time/time.h"
25 #include "tensorflow/core/activity_watcher/activity.h"
26 #include "tensorflow/core/activity_watcher/activity_utils.h"
27 #include "tensorflow/core/common_runtime/graph_constructor.h"
28 #include "tensorflow/core/common_runtime/graph_runner.h"
29 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
30 #include "tensorflow/core/common_runtime/renamed_device.h"
31 #include "tensorflow/core/common_runtime/threadpool_device.h"
32 #include "tensorflow/core/data/captured_function.h"
33 #include "tensorflow/core/data/dataset_utils.h"
34 #include "tensorflow/core/data/finalization_utils.h"
35 #include "tensorflow/core/data/metric_utils.h"
36 #include "tensorflow/core/data/root_dataset.h"
37 #include "tensorflow/core/data/serialization_utils.h"
38 #include "tensorflow/core/data/utils.h"
39 #include "tensorflow/core/framework/cancellation.h"
40 #include "tensorflow/core/framework/function.h"
41 #include "tensorflow/core/framework/metrics.h"
42 #include "tensorflow/core/framework/op_kernel.h"
43 #include "tensorflow/core/framework/partial_tensor_shape.h"
44 #include "tensorflow/core/framework/resource_op_kernel.h"
45 #include "tensorflow/core/framework/stats_aggregator.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/types.h"
48 #include "tensorflow/core/framework/variant_op_registry.h"
49 #include "tensorflow/core/framework/variant_tensor_data.h"
50 #include "tensorflow/core/kernels/data/optional_ops.h"
51 #include "tensorflow/core/kernels/ops_util.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/core/refcount.h"
54 #include "tensorflow/core/lib/core/threadpool.h"
55 #include "tensorflow/core/lib/gtl/cleanup.h"
56 #include "tensorflow/core/lib/random/random.h"
57 #include "tensorflow/core/lib/strings/strcat.h"
58 #include "tensorflow/core/lib/strings/stringprintf.h"
59 #include "tensorflow/core/platform/casts.h"
60 #include "tensorflow/core/platform/env.h"
61 #include "tensorflow/core/platform/errors.h"
62 #include "tensorflow/core/platform/mem.h"
63 #include "tensorflow/core/platform/mutex.h"
64 #include "tensorflow/core/platform/refcount.h"
65 #include "tensorflow/core/platform/resource.h"
66 #include "tensorflow/core/profiler/lib/traceme.h"
67 #include "tensorflow/core/profiler/lib/traceme_encode.h"
68 #include "tensorflow/core/public/session_options.h"
69
70 namespace tensorflow {
71 namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following ops.

// `kAnonymousIterator` is the name reported by
// AnonymousIteratorHandleOp::name(). The V2/V3 strings are presumably the
// names of later versions of that op -- confirm against their users.
constexpr char kAnonymousIterator[] = "AnonymousIterator";
constexpr char kAnonymousIteratorV2[] = "AnonymousIteratorV2";
constexpr char kAnonymousIteratorV3[] = "AnonymousIteratorV3";

// Type name that IteratorStateVariant reports and is registered under.
constexpr char kIteratorVariantTypeName[] = "tensorflow::Iterator";

// Names of the attrs read by the kernel constructors in this file.
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";

}  // namespace
85
86 /* static */ constexpr const char* const
87 SerializeIteratorOp::kExternalStatePolicy;
88
IteratorResource(Env * env,const DataTypeVector & output_dtypes,const std::vector<PartialTensorShape> & output_shapes,std::unique_ptr<DeviceMgr> device_mgr,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * flr)89 IteratorResource::IteratorResource(
90 Env* env, const DataTypeVector& output_dtypes,
91 const std::vector<PartialTensorShape>& output_shapes,
92 std::unique_ptr<DeviceMgr> device_mgr,
93 std::unique_ptr<FunctionLibraryDefinition> flib_def,
94 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
95 FunctionLibraryRuntime* flr)
96 : metrics_collector_(flr->device()->device_type(), *env),
97 unbounded_thread_pool_(env, "tf_data_iterator_resource"),
98 device_mgr_(std::move(device_mgr)),
99 iterator_state_(std::make_shared<State>(std::move(flib_def),
100 std::move(pflr), flr,
101 /*iterator=*/nullptr)),
102 output_dtypes_(output_dtypes),
103 output_shapes_(output_shapes) {
104 VLOG(2) << "creating iterator resource";
105 }
106
~IteratorResource()107 IteratorResource::~IteratorResource() {
108 VLOG(2) << "destroying iterator resource";
109 }
110
GetNext(OpKernelContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)111 Status IteratorResource::GetNext(OpKernelContext* ctx,
112 std::vector<Tensor>* out_tensors,
113 bool* end_of_sequence) {
114 std::shared_ptr<State> captured_state;
115 {
116 tf_shared_lock l(mu_);
117 captured_state = iterator_state_;
118 }
119 if (!captured_state->iterator()) {
120 return errors::FailedPrecondition(
121 "GetNext() failed because the iterator has not been initialized. "
122 "Ensure that you have run the initializer operation for this iterator "
123 "before getting the next element.");
124 }
125 IteratorContext::Params params(ctx);
126 params.flr = captured_state->flr();
127 params.function_handle_cache = captured_state->function_handle_cache();
128 params.resource_mgr = captured_state->resource_mgr();
129 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
130 params.thread_pool = &unbounded_thread_pool_;
131 params.cancellation_manager = captured_state->cancellation_manager();
132 std::function<void()> deregister_fn;
133 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
134 ctx->cancellation_manager(),
135 [cm = params.cancellation_manager]() { cm->StartCancel(); },
136 &deregister_fn));
137 auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
138
139 const absl::Time start_time = metrics_collector_.RecordStart();
140 auto iterator_ = captured_state->iterator();
141 auto status = iterator_->GetNext(IteratorContext(std::move(params)),
142 out_tensors, end_of_sequence);
143 metrics_collector_.RecordStop(start_time, *out_tensors);
144 return status;
145 }
146
Save(SerializationContext * ctx,IteratorStateWriter * writer)147 Status IteratorResource::Save(SerializationContext* ctx,
148 IteratorStateWriter* writer) {
149 std::shared_ptr<State> captured_state;
150 {
151 tf_shared_lock l(mu_);
152 captured_state = iterator_state_;
153 }
154 auto iterator_ = captured_state->iterator();
155 if (iterator_) {
156 return iterator_->Save(ctx, writer);
157 }
158 return errors::FailedPrecondition(
159 "Save() failed because the iterator has not been initialized. Ensure "
160 "that you have run the initializer operation for this iterator before "
161 "saving it.");
162 }
163
Restore(OpKernelContext * ctx,IteratorStateReader * reader)164 Status IteratorResource::Restore(OpKernelContext* ctx,
165 IteratorStateReader* reader) {
166 const DatasetBase* dataset;
167 std::shared_ptr<State> new_state;
168 const DatasetBase* input_dataset;
169 {
170 tf_shared_lock l(mu_);
171 if (!iterator_state_->iterator()) {
172 return errors::FailedPrecondition(
173 "Restore() failed because the iterator has not been initialized. "
174 "Ensure that you have run the initializer operation for this "
175 "iterator before restoring it.");
176 }
177 auto iterator_ = iterator_state_->iterator();
178 dataset = iterator_->dataset();
179 // Hang onto a reference until we've created the new iterator, which will
180 // then hold its own reference to keep the dataset alive.
181 dataset->Ref();
182 new_state =
183 std::make_shared<State>(iterator_state_->flib_def(),
184 iterator_state_->pflr(), iterator_state_->flr(),
185 /*iterator=*/nullptr);
186 input_dataset = iterator_state_->dataset();
187 }
188 core::ScopedUnref scoped_unref(dataset);
189 IteratorContext::Params params(ctx);
190 params.flr = new_state->flr();
191 params.function_handle_cache = new_state->function_handle_cache();
192 params.resource_mgr = new_state->resource_mgr();
193 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
194 params.thread_pool = &unbounded_thread_pool_;
195 params.cancellation_manager = new_state->cancellation_manager();
196 std::function<void()> deregister_fn;
197 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
198 ctx->cancellation_manager(),
199 [cm = params.cancellation_manager]() { cm->StartCancel(); },
200 &deregister_fn));
201 auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
202 std::unique_ptr<IteratorBase> iterator_base;
203 TF_RETURN_IF_ERROR(dataset->MakeIteratorFromCheckpoint(
204 IteratorContext(std::move(params)), "Iterator", reader, &iterator_base));
205 new_state->DowncastAndSetIteratorAndDataset(std::move(iterator_base),
206 input_dataset);
207
208 mutex_lock l(mu_);
209 std::swap(iterator_state_, new_state);
210 return OkStatus();
211 }
212
SetIteratorFromDataset(OpKernelContext * ctx,const DatasetBase * dataset)213 Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
214 const DatasetBase* dataset) {
215 std::shared_ptr<State> new_state;
216 {
217 tf_shared_lock l(mu_);
218 new_state =
219 std::make_shared<State>(iterator_state_->flib_def(),
220 iterator_state_->pflr(), iterator_state_->flr(),
221 /*iterator=*/nullptr);
222 }
223
224 // Create new iterator.
225 IteratorContext::Params params(ctx);
226 params.flr = new_state->flr();
227 params.function_handle_cache = new_state->function_handle_cache();
228 params.resource_mgr = new_state->resource_mgr();
229 params.thread_factory = unbounded_thread_pool_.get_thread_factory();
230 params.thread_pool = &unbounded_thread_pool_;
231 params.cancellation_manager = new_state->cancellation_manager();
232 std::function<void()> deregister_fn;
233 TF_RETURN_IF_ERROR(RegisterCancellationCallback(
234 ctx->cancellation_manager(),
235 [cm = params.cancellation_manager]() { cm->StartCancel(); },
236 &deregister_fn));
237 auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
238
239 std::unique_ptr<IteratorBase> iterator;
240 if (ctx->function_library()->device()->device_type() == DEVICE_CPU) {
241 DatasetBase* finalized_dataset;
242 TF_ASSIGN_OR_RETURN(finalized_dataset, GetFinalizedDataset(ctx, dataset));
243 TF_RETURN_IF_ERROR(finalized_dataset->MakeIterator(
244 IteratorContext(std::move(params)),
245 /*parent=*/nullptr, "Iterator", &iterator));
246 } else {
247 TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
248 /*parent=*/nullptr, "Iterator",
249 &iterator));
250 }
251 TF_RETURN_IF_ERROR(
252 VerifyTypesMatch(output_dtypes_, iterator->output_dtypes()));
253 TF_RETURN_IF_ERROR(
254 VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
255
256 new_state->DowncastAndSetIteratorAndDataset(std::move(iterator), dataset);
257
258 mutex_lock l(mu_);
259 std::swap(iterator_state_, new_state);
260 return OkStatus();
261 }
262
263 namespace {
264
265 // Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
266 // The get() method returns an VariantTensorData object which contains all the
267 // state needed to restore a single iterator.
268 //
269 // Usage example:
270 //
271 // Encoding:
272 //
273 // Tensor t(DT_VARIANT, TensorShape({}));
274 // t->scalar<Variant>()() = IteratorStateVariant();
275 //
276 // Encode() sets the type_name of the VariantTensorData object to
277 // IteratorStateVariant::TypeName().
278 //
279 // Decoding:
280 //
281 // Variant v = <VariantTensorDataProto object>;
282 // DecodeUnaryVariant(&v);
283 // IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
284 // IteratorStateReader reader({wrapper->GetData()});
285 // iterator_resource->Restore(ctx, &reader);
286 //
287 // The type_name of the VariantTensorData object to be decoded must
288 // match IteratorStateVariant::TypeName().
289 class IteratorStateVariant {
290 public:
IteratorStateVariant()291 IteratorStateVariant() : data_(nullptr) {}
IteratorStateVariant(const IteratorStateVariant & other)292 IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
293 if (other.data_) {
294 Decode(*other.data_);
295 }
296 }
297 IteratorStateVariant& operator=(IteratorStateVariant&& other) = default;
298 IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete;
299
300 // Initializes `this` from a VariantTensorData object.
InitializeFromVariantData(std::unique_ptr<VariantTensorData> d)301 Status InitializeFromVariantData(std::unique_ptr<VariantTensorData> d) {
302 data_ = std::move(d);
303 return OkStatus();
304 }
305
TypeName() const306 string TypeName() const { return kIteratorVariantTypeName; }
Encode(VariantTensorData * data) const307 void Encode(VariantTensorData* data) const { *data = *data_; }
Decode(VariantTensorData data)308 bool Decode(VariantTensorData data) {
309 if (data.type_name() != TypeName()) {
310 return false;
311 }
312 auto tensor_data = std::make_unique<VariantTensorData>();
313 std::swap(*tensor_data, data);
314 data_ = std::move(tensor_data);
315 return true;
316 }
317
318 // Returns a borrowed pointer to the underlying VariantTensorData.
GetData() const319 const VariantTensorData* GetData() const { return data_.get(); }
320
DebugString() const321 string DebugString() const {
322 if (data_) {
323 return strings::StrCat("IteratorStateVariant<", data_->DebugString(),
324 ">");
325 } else {
326 return strings::StrCat("IteratorStateVariant<empty>");
327 }
328 }
329
330 private:
331 std::unique_ptr<VariantTensorData> data_;
332 };
333
334 // Register the reader class in the global variant decode_fn registry
335 // so that a Variant containing a serialized representation of iterator state
336 // can be decoded using DecodeUnaryVariant. If we don't do this we will need
337 // to manually decode the returned Variant using MaybeDecodeAndCopy in
338 // DeserializeIteratorOp which is not recommended.
339 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
340 kIteratorVariantTypeName);
341
342 // A helper class that uses a list of IteratorStateVariant objects to represent
343 // the state for an iterator resource. It exposes methods that help with
344 // saving and restoring of this state. Sample usage
345 // Saving:
346 // IteratorVariantSerializer serializer;
347 // serializer.InitializeFromIterator(iterator_resource);
348 // Tensor serialized_t;
349 // serializer.Serialize(&serialized_t);
350 //
351 // Restoring:
352 // IteratorVariantSerializer serializer;
353 // serializer.InitFromTensor(ctx->input(0));
354 // IteratorStateReader* reader = serializer.GetReader();
355 // iterator_resource->Restore(ctx, reader);
356 class IteratorVariantSerializer {
357 public:
IteratorVariantSerializer()358 IteratorVariantSerializer() {}
359
360 // Calls `Save` on the iterator_resource to build up the list of
361 // IteratorStateVariant objects.
InitializeFromIterator(SerializationContext * serialization_ctx,IteratorResource * iterator_resource)362 Status InitializeFromIterator(SerializationContext* serialization_ctx,
363 IteratorResource* iterator_resource) {
364 VariantTensorDataWriter writer;
365 TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
366 std::vector<std::unique_ptr<VariantTensorData>> data;
367 writer.ReleaseData(&data);
368 variants_.clear();
369 variants_.reserve(data.size());
370 for (auto& it : data) {
371 IteratorStateVariant v;
372 TF_RETURN_IF_ERROR(v.InitializeFromVariantData(std::move(it)));
373 variants_.push_back(v);
374 }
375 num_tensors_ = variants_.size();
376 can_serialize_ = true;
377 return OkStatus();
378 }
379
380 // Initializes `this` from `serialized_t` while restoring the iterator state.
InitFromTensor(const Tensor * serialized_t)381 Status InitFromTensor(const Tensor* serialized_t) {
382 int64_t num_tensors = serialized_t->dim_size(0);
383 auto serialized_vec = serialized_t->vec<Variant>();
384 std::vector<const VariantTensorData*> data;
385 data.reserve(num_tensors);
386 for (int i = 0; i < num_tensors; ++i) {
387 auto* w = serialized_vec(i).get<IteratorStateVariant>();
388 if (!w) {
389 return errors::Internal(
390 "Cannot initialize an iterator from tensor ",
391 serialized_vec(i).DebugString(),
392 ". Expected a variant tensor of type IteratorStateVariant");
393 }
394 data.push_back(w->GetData());
395 }
396 reader_ = std::make_unique<VariantTensorDataReader>(data);
397 num_tensors_ = data.size();
398 return OkStatus();
399 }
400
NumTensors()401 int64_t NumTensors() { return num_tensors_; }
402
403 // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects
404 // that InitializeFromIterator was called before.
Serialize(Tensor * serialized)405 Status Serialize(Tensor* serialized) {
406 if (!can_serialize_) {
407 return errors::InvalidArgument(
408 "Please call InitializeFromIterator before calling Serialize.");
409 }
410 int64_t size = variants_.size();
411 for (int64_t i = 0; i < size; ++i) {
412 if (variants_[i].GetData() == nullptr) {
413 return errors::Internal(
414 "Cannot serialize an empty IteratorStateVariant");
415 }
416 serialized->vec<Variant>()(i) = variants_[i];
417 }
418 return OkStatus();
419 }
420
421 // Returns an IteratorStateReader to restore iterator state. Expects that
422 // InitFromTensor was called before.
GetReader()423 IteratorStateReader* GetReader() { return reader_.get(); }
424
425 private:
426 bool can_serialize_ = false;
427 int64_t num_tensors_;
428 std::vector<IteratorStateVariant> variants_;
429 std::unique_ptr<IteratorStateReader> reader_;
430 };
431
432 } // namespace
433
434 // Note that IteratorHandleOp holds a reference to the resource it creates. If
435 // cleaning up resources with DestroyResourceOp is important, consider creating
436 // resource containers with AnonymousIteratorHandleOp instead.
IteratorHandleOp(OpKernelConstruction * ctx)437 IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
438 : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
439 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
440 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
441 OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
442 }
443
444 // The resource is deleted from the resource manager only when it is private
445 // to kernel. Ideally the resource should be deleted when it is no longer held
446 // by anyone, but it would break backward compatibility.
~IteratorHandleOp()447 IteratorHandleOp::~IteratorHandleOp() {
448 if (resource_ != nullptr) {
449 resource_->Unref();
450 if (cinfo_.resource_is_private_to_kernel()) {
451 if (!cinfo_.resource_manager()
452 ->template Delete<IteratorResource>(cinfo_.container(),
453 cinfo_.name())
454 .ok()) {
455 // Do nothing; the resource can have been deleted by session resets.
456 }
457 }
458 }
459 }
460
Compute(OpKernelContext * context)461 void IteratorHandleOp::Compute(OpKernelContext* context)
462 TF_LOCKS_EXCLUDED(mu_) {
463 {
464 mutex_lock l(mu_);
465 if (resource_ == nullptr) {
466 FunctionLibraryRuntime* flr;
467 std::unique_ptr<DeviceMgr> device_mgr(nullptr);
468 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
469 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
470 // If the iterator is shared then we construct a new FLR, and pass that
471 // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
472 // functions from the iterator. We may add this functionality if there
473 // is sufficient demand, but it will require a significant refactoring.
474 if (!name_.empty()) {
475 flr = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
476 } else {
477 OP_REQUIRES_OK(context, context->function_library()->Clone(
478 &flib_def, &pflr, &flr, true));
479 }
480
481 ResourceMgr* mgr = context->resource_manager();
482 OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
483
484 IteratorResource* resource;
485 OP_REQUIRES_OK(
486 context,
487 mgr->LookupOrCreate<IteratorResource>(
488 cinfo_.container(), cinfo_.name(), &resource,
489 [context, flr, &device_mgr, &flib_def, &pflr,
490 this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
491 *ret = new IteratorResource(
492 context->env(), output_dtypes_, output_shapes_,
493 std::move(device_mgr), std::move(flib_def), std::move(pflr),
494 flr);
495 return OkStatus();
496 }));
497
498 Status s = VerifyResource(resource);
499 if (TF_PREDICT_FALSE(!s.ok())) {
500 resource->Unref();
501 context->SetStatus(s);
502 return;
503 }
504
505 resource_ = resource;
506 }
507 }
508 OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
509 context, 0, cinfo_.container(), cinfo_.name(),
510 TypeIndex::Make<IteratorResource>()));
511 }
512
VerifyResource(IteratorResource * resource)513 Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
514 TF_RETURN_IF_ERROR(
515 VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
516 TF_RETURN_IF_ERROR(
517 VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
518 return OkStatus();
519 }
520
CreatePrivateFLR(OpKernelContext * ctx,std::unique_ptr<DeviceMgr> * device_mgr,std::unique_ptr<FunctionLibraryDefinition> * flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> * pflr)521 FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
522 OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
523 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
524 std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
525 // Wrap the existing device in order to see any captured resources
526 // in its resource manager. The existing device will outlive the
527 // IteratorResource, because we are storing the IteratorResource
528 // in that device's resource manager.
529
530 *device_mgr =
531 std::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
532 ctx->device()->name(), down_cast<Device*>(ctx->device()),
533 false /* owns_underlying */, false /* isolate_session_state */));
534 *flib_def = std::make_unique<FunctionLibraryDefinition>(
535 *ctx->function_library()->GetFunctionLibraryDefinition());
536 const auto* config = ctx->function_library()->config_proto();
537 *pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
538 device_mgr->get(), ctx->env(),
539 /*config=*/config, graph_def_version_, flib_def->get(),
540 config->graph_options().optimizer_options());
541
542 return (*pflr)->GetFLR(ctx->device()->name());
543 }
544
545 // Like IteratorHandleOp, but creates handles which are never shared, and does
546 // not hold a reference to these handles. The latter is important for eager
547 // execution, since OpKernel instances generally live as long as the program
548 // running them.
AnonymousIteratorHandleOp(OpKernelConstruction * context)549 AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
550 OpKernelConstruction* context)
551 : AnonymousResourceOp<IteratorResource>(
552 context,
553 /* ref_counting */
554 // Always disable it for TfLite environment and let ResourceMgr
555 // release it at the end of execution because the TfLite
556 // kernel will hold this resource tensor till the end.
557 false,
558 /* return_deleter */
559 // Alwasy disable it for TfLite environment and let ResourceMgr
560 // release it at the end of exectuion because the TfLite
561 // kernel will hold this resource tensor till the end.
562 false),
563 graph_def_version_(context->graph_def_version()) {
564 OP_REQUIRES_OK(context, context->GetAttr(kOutputTypes, &output_dtypes_));
565 OP_REQUIRES_OK(context, context->GetAttr(kOutputShapes, &output_shapes_));
566 }
567
name()568 string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; }
569
CreateResource(OpKernelContext * ctx,std::unique_ptr<FunctionLibraryDefinition> flib_def,std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,FunctionLibraryRuntime * lib,IteratorResource ** resource)570 Status AnonymousIteratorHandleOp::CreateResource(
571 OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
572 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
573 FunctionLibraryRuntime* lib, IteratorResource** resource) {
574 std::unique_ptr<DeviceMgr> device_mgr(nullptr);
575 *resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
576 std::move(device_mgr), std::move(flib_def),
577 std::move(pflr), lib);
578 return OkStatus();
579 }
580
HybridAsyncOpKernel(OpKernelConstruction * ctx,const char * background_worker_name)581 HybridAsyncOpKernel::HybridAsyncOpKernel(OpKernelConstruction* ctx,
582 const char* background_worker_name)
583 : AsyncOpKernel(ctx),
584 background_worker_(ctx->env(), background_worker_name) {}
585
ComputeAsync(OpKernelContext * ctx,DoneCallback done)586 void HybridAsyncOpKernel::ComputeAsync(OpKernelContext* ctx,
587 DoneCallback done) {
588 background_worker_.Schedule([this, ctx, done = std::move(done)]() {
589 ctx->SetStatus(DoCompute(ctx));
590 done();
591 });
592 }
593
Compute(OpKernelContext * ctx)594 void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) {
595 ctx->SetStatus(DoCompute(ctx));
596 }
597
DoCompute(OpKernelContext * ctx)598 Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
599 tensorflow::ResourceTagger tag(kTFDataResourceTag,
600 ctx->op_kernel().type_string());
601 DatasetBase* dataset;
602 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
603 IteratorResource* iterator_resource;
604 TF_RETURN_IF_ERROR(
605 LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
606 core::ScopedUnref unref_iterator(iterator_resource);
607 return iterator_resource->SetIteratorFromDataset(ctx, dataset);
608 }
609
DoCompute(OpKernelContext * ctx)610 Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) {
611 tensorflow::ResourceTagger tag(kTFDataResourceTag,
612 ctx->op_kernel().type_string());
613 const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
614 // The iterator resource is guaranteed to exist because the variant tensor
615 // wrapping the deleter is provided as an unused input to this op, which
616 // guarantees that it has not run yet.
617 return DeleteResource(ctx, handle);
618 }
619
620 namespace {
621
622 class ToSingleElementOp : public AsyncOpKernel {
623 public:
ToSingleElementOp(OpKernelConstruction * ctx)624 explicit ToSingleElementOp(OpKernelConstruction* ctx)
625 : AsyncOpKernel(ctx),
626 metrics_collector_(ctx->device()->attributes().device_type(),
627 *ctx->env()),
628 unbounded_threadpool_(ctx->env(), "tf_data_to_single_element") {
629 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
630 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
631 }
632
ComputeAsync(OpKernelContext * ctx,DoneCallback done)633 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
634 unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
635 ctx->SetStatus(DoCompute(ctx));
636 done();
637 });
638 }
639
Compute(OpKernelContext * ctx)640 void Compute(OpKernelContext* ctx) override {
641 ctx->SetStatus(DoCompute(ctx));
642 }
643
644 private:
DoCompute(OpKernelContext * ctx)645 Status DoCompute(OpKernelContext* ctx) {
646 profiler::TraceMe traceme(
647 [&] {
648 return profiler::TraceMeEncode("ToSingleElementOp::DoCompute",
649 {{"id", ctx->step_id()}});
650 },
651 profiler::kInfo);
652 tensorflow::ResourceTagger tag(kTFDataResourceTag,
653 ctx->op_kernel().type_string());
654 DatasetBase* dataset;
655 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
656
657 IteratorContext::Params params(ctx);
658 ResourceMgr resource_mgr;
659 params.resource_mgr = &resource_mgr;
660 CancellationManager cancellation_manager(ctx->cancellation_manager());
661 params.cancellation_manager = &cancellation_manager;
662
663 IteratorContext iter_ctx(std::move(params));
664 std::unique_ptr<IteratorBase> iterator;
665 TF_RETURN_IF_ERROR(dataset->MakeIterator(
666 &iter_ctx, /*parent=*/nullptr, "SingleElementIterator", &iterator));
667
668 std::vector<Tensor> components;
669 components.reserve(dataset->output_dtypes().size());
670 bool end_of_sequence = false;
671
672 const absl::Time start_time = metrics_collector_.RecordStart();
673 TF_RETURN_IF_ERROR(
674 iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
675 metrics_collector_.RecordStop(start_time, components);
676
677 if (end_of_sequence) {
678 return errors::InvalidArgument("Dataset was empty.");
679 }
680 TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
681 TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
682 for (int i = 0; i < components.size(); ++i) {
683 ctx->set_output(i, components[i]);
684 }
685
686 components.clear();
687 TF_RETURN_IF_ERROR(
688 iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
689 if (!end_of_sequence) {
690 return errors::InvalidArgument("Dataset had more than one element.");
691 }
692 return OkStatus();
693 }
694
695 IteratorMetricsCollector metrics_collector_;
696 UnboundedThreadPool unbounded_threadpool_;
697 DataTypeVector output_types_;
698 std::vector<PartialTensorShape> output_shapes_;
699 };
700
701 class OneShotIteratorOp : public AsyncOpKernel {
702 public:
OneShotIteratorOp(OpKernelConstruction * ctx)703 explicit OneShotIteratorOp(OpKernelConstruction* ctx)
704 : AsyncOpKernel(ctx),
705 background_worker_(ctx->env(), "tf_data_one_shot_iterator"),
706 graph_def_version_(ctx->graph_def_version())
707
708 {
709 string shared_name;
710 OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name));
711 OP_REQUIRES(ctx, shared_name.empty(),
712 errors::InvalidArgument("OneShotIteratorOp does not currently "
713 "support the 'shared_name' attr."));
714 OP_REQUIRES_OK(ctx,
715 ctx->GetAttr("dataset_factory", &dataset_factory_func_));
716 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
717 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
718 }
719
~OneShotIteratorOp()720 ~OneShotIteratorOp() override {
721 if (iterator_resource_ != nullptr) {
722 iterator_resource_->Unref();
723 if (!cinfo_.resource_manager()
724 ->Delete<IteratorResource>(cinfo_.container(), cinfo_.name())
725 .ok()) {
726 // Do nothing; the resource can have been deleted by session resets.
727 }
728 }
729 }
730
731 // NOTE(mrry): This is based on `ResourceOpKernel<T>::Compute()`,
732 // but due to the fact that `ResourceOpKernel<T>::CreateResource()`
733 // does not provide access to the `OpKernelContext*` and we need
734 // this to invoke the factory function, it's not possible to
735 // implement this kernel by implementing `CreateResource()`.
736 // Furthermore, due to the fact that this kernel might block when
737 // running the initialization function, we must implement this
738 // kernel as an async kernel.
ComputeAsync(OpKernelContext * ctx,DoneCallback done)739 void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
740 tensorflow::ResourceTagger tag(kTFDataResourceTag,
741 ctx->op_kernel().type_string());
742 {
743 mutex_lock l(mu_);
744 if (iterator_resource_ == nullptr && initialization_status_.ok()) {
745 // The initialization thread will call `done`.
746 if (!initialization_started_) {
747 // TODO(mrry): Convert the initialization code to use
748 // callbacks instead of wasting a thread.
749 background_worker_.Schedule([this, ctx, done]() { Init(ctx, done); });
750 initialization_started_ = true;
751 } else {
752 done_callbacks_.emplace_back(ctx, std::move(done));
753 }
754 return;
755 }
756 }
757 ProduceOutput(ctx, done);
758 }
759
760 private:
Init(OpKernelContext * ctx,const DoneCallback & done)761 void Init(OpKernelContext* ctx, const DoneCallback& done) {
762 IteratorResource* iterator = nullptr;
763 ContainerInfo cinfo;
764 Status s = TryInit(ctx, &iterator, &cinfo);
765
766 std::vector<std::pair<OpKernelContext*, DoneCallback>> callbacks_to_run;
767 {
768 mutex_lock l(mu_);
769 if (s.ok()) {
770 iterator_resource_ = iterator;
771 cinfo_ = cinfo;
772 }
773 initialization_status_ = s;
774 std::swap(done_callbacks_, callbacks_to_run);
775 }
776
777 for (auto&& ctx_done : callbacks_to_run) {
778 ProduceOutput(ctx_done.first, ctx_done.second);
779 }
780 ProduceOutput(ctx, done);
781 }
782
TryInit(OpKernelContext * ctx,IteratorResource ** iterator,ContainerInfo * cinfo)783 Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
784 ContainerInfo* cinfo) {
785 TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def()));
786
787 FunctionLibraryRuntime* flr;
788 std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
789 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
790 TF_RETURN_IF_ERROR(
791 ctx->function_library()->Clone(&flib_def, &pflr, &flr, true));
792
793 // Create an IteratorResource that will hold the iterator for this op.
794 TF_RETURN_IF_ERROR(
795 ctx->resource_manager()->LookupOrCreate<IteratorResource>(
796 cinfo->container(), cinfo->name(), iterator,
797 [ctx, flr, this, &flib_def, &pflr](IteratorResource** ret)
798 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
799 *ret = new IteratorResource(
800 ctx->env(), output_dtypes_, output_shapes_,
801 /*device_mgr=*/nullptr, std::move(flib_def),
802 std::move(pflr), flr);
803 return OkStatus();
804 }));
805
806 core::ScopedUnref unref_iterator(*iterator);
807
808 TF_RETURN_IF_ERROR(
809 VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes()));
810 TF_RETURN_IF_ERROR(
811 VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes()));
812
813 // Call the dataset_factory_func_ to create a new dataset,
814 // over which this op will iterate.
815 FunctionLibraryRuntime::Handle f_handle;
816 TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
817 dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()),
818 &f_handle));
819 FunctionLibraryRuntime::Options opts;
820 opts.cancellation_manager = ctx->cancellation_manager();
821 ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
822 ctx->resource_manager()->Cleanup(name).IgnoreError();
823 });
824 opts.step_container = &step_container;
825 opts.runner = ctx->runner();
826 opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
827 std::vector<Tensor> return_values;
828 TF_RETURN_IF_ERROR(ctx->function_library()->RunSync(
829 std::move(opts), f_handle, {}, &return_values));
830 if (return_values.size() != 1 || return_values[0].dtype() != DT_VARIANT ||
831 !TensorShapeUtils::IsScalar(return_values[0].shape())) {
832 return errors::InvalidArgument(
833 "The `dataset_factory` function must return "
834 "a single scalar of dtype DT_VARIANT.");
835 }
836
837 // Create an iterator for the dataset that was created in the
838 // factory function.
839 DatasetBase* dataset;
840 TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
841 TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
842 (*iterator)->Ref();
843 return OkStatus();
844 }
845
ProduceOutput(OpKernelContext * ctx,const DoneCallback & done)846 void ProduceOutput(OpKernelContext* ctx, const DoneCallback& done) {
847 Tensor* handle;
848 OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle),
849 done);
850 Status s;
851 {
852 mutex_lock l(mu_);
853 s = initialization_status_;
854 if (s.ok()) {
855 handle->scalar<ResourceHandle>()() =
856 MakeResourceHandle<IteratorResource>(ctx, cinfo_.container(),
857 cinfo_.name());
858 }
859 }
860 OP_REQUIRES_OK_ASYNC(ctx, s, done);
861 done();
862 }
863
// Value of the `dataset_factory` attr; `TryInit()` instantiates and runs this
// function to obtain the dataset to iterate over.
NameAttrList dataset_factory_func_;
// Values of the output types/shapes attrs. They are handed to a newly created
// IteratorResource and checked against the resource in `TryInit()`.
DataTypeVector output_dtypes_;
std::vector<PartialTensorShape> output_shapes_;

// Runs `Init()`; the first `ComputeAsync()` call schedules it here.
BackgroundWorker background_worker_;

// Guards all of the initialization state below.
mutex mu_;
// Container/name of the iterator resource; set by `Init()` on success.
ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
// Reference held by this kernel on the iterator resource; stays null until
// `Init()` succeeds and is released in the destructor.
IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr;

// Set once the first `ComputeAsync()` call has scheduled `Init()`.
bool initialization_started_ TF_GUARDED_BY(mu_) = false;
// Outcome of `TryInit()`, recorded by `Init()`.
Status initialization_status_ TF_GUARDED_BY(mu_);
// Invocations that arrived after initialization was scheduled but before it
// finished; `Init()` drains this list and completes each entry.
std::vector<std::pair<OpKernelContext*, DoneCallback>> done_callbacks_
    TF_GUARDED_BY(mu_);
// NOTE(review): not referenced by any method visible in this part of the
// class; presumably initialized by the constructor — confirm.
const int graph_def_version_;
879 };
880
881 } // namespace
882
AsAsync()883 AsyncOpKernel* IteratorGetNextOp::AsAsync() {
884 return type_string() == "IteratorGetNextSync" ? nullptr : this;
885 }
886
RecordElementSize(const std::vector<Tensor> element,profiler::TraceMe * traceme)887 void RecordElementSize(const std::vector<Tensor> element,
888 profiler::TraceMe* traceme) {
889 traceme->AppendMetadata([&]() {
890 int64_t element_size = 0;
891 for (const auto& component : element) {
892 element_size += component.TotalBytes();
893 }
894 return profiler::TraceMeEncode({{"element_size", element_size}});
895 });
896 }
897
DoCompute(OpKernelContext * ctx)898 Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
899 VLOG(3) << "IteratorGetNextOp enter. iter_id=" << ctx->frame_iter().iter_id;
900 auto cleanup = gtl::MakeCleanup([ctx] {
901 VLOG(3) << "IteratorGetNextOp exit. iter_id=" << ctx->frame_iter().iter_id;
902 });
903 activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
904 return activity_watcher::ActivityFromContext(
905 ctx, "IteratorGetNextOp::DoCompute",
906 activity_watcher::ActivityCategory::kDatasetOp);
907 });
908 profiler::TraceMe traceme(
909 [&] {
910 return profiler::TraceMeEncode(
911 "IteratorGetNextOp::DoCompute",
912 {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
913 },
914 profiler::kInfo);
915 tensorflow::ResourceTagger tag(kTFDataResourceTag,
916 ctx->op_kernel().type_string());
917 IteratorResource* iterator;
918 TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
919 core::ScopedUnref unref_iterator(iterator);
920 std::vector<Tensor> components;
921 bool end_of_sequence = false;
922
923 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
924 if (end_of_sequence) {
925 return errors::OutOfRange("End of sequence");
926 }
927 TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
928 TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
929 RecordElementSize(components, &traceme);
930 for (int i = 0; i < components.size(); ++i) {
931 ctx->set_output(i, components[i]);
932 }
933 return OkStatus();
934 }
935
DoCompute(OpKernelContext * ctx)936 Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
937 VLOG(3) << "IteratorGetNextAsOptionalOp enter. iter_id="
938 << ctx->frame_iter().iter_id;
939 auto cleanup = gtl::MakeCleanup([ctx] {
940 VLOG(3) << "IteratorGetNextAsOptionalOp exit. iter_id="
941 << ctx->frame_iter().iter_id;
942 });
943 activity_watcher::ActivityScope activity_scope([ctx = ctx]() {
944 return activity_watcher::ActivityFromContext(
945 ctx, "IteratorGetNextAsOptionalOp::DoCompute",
946 activity_watcher::ActivityCategory::kDatasetOp);
947 });
948 profiler::TraceMe traceme(
949 [&] {
950 return profiler::TraceMeEncode(
951 "IteratorGetNextAsOptionalOp::DoCompute",
952 {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}});
953 },
954 profiler::kInfo);
955 tensorflow::ResourceTagger tag(kTFDataResourceTag,
956 ctx->op_kernel().type_string());
957 IteratorResource* iterator;
958 TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
959 core::ScopedUnref unref_iterator(iterator);
960 std::vector<Tensor> components;
961 bool end_of_sequence = false;
962
963 TF_RETURN_IF_ERROR(iterator->GetNext(ctx, &components, &end_of_sequence));
964
965 if (end_of_sequence) {
966 return WriteOptionalNoneToOutput(ctx, 0);
967 } else {
968 RecordElementSize(components, &traceme);
969 for (int i = 0; i < components.size(); ++i) {
970 if (components[i].dtype() != output_types_[i]) {
971 return errors::InvalidArgument(
972 "The given optional does not match the expected type for "
973 "component ",
974 i, ". Expected: ", DataTypeString(output_types_[i]),
975 ". Actual: ", DataTypeString(components[i].dtype()), ".");
976 }
977 if (!output_shapes_[i].IsCompatibleWith(components[i].shape())) {
978 return errors::InvalidArgument(
979 "The given optional does not match the expected shape "
980 "for component ",
981 i, ". Expected: ", output_shapes_[i].DebugString(),
982 ". Actual: ", components[i].shape().DebugString(), ".");
983 }
984 }
985 return WriteOptionalWithValueToOutput(ctx, 0, std::move(components));
986 }
987 }
988
Compute(OpKernelContext * ctx)989 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
990 const Tensor& resource_handle_t = ctx->input(0);
991 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
992 errors::InvalidArgument("resource_handle must be a scalar"));
993
994 // Validate that the handle corresponds to a real resource, and
995 // that it is an IteratorResource.
996 IteratorResource* iterator_resource;
997 OP_REQUIRES_OK(
998 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
999 iterator_resource->Unref();
1000
1001 Tensor* string_handle_t;
1002 OP_REQUIRES_OK(ctx,
1003 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
1004 string_handle_t->scalar<tstring>()() =
1005 resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
1006 }
1007
IteratorFromStringHandleOp(OpKernelConstruction * ctx)1008 IteratorFromStringHandleOp::IteratorFromStringHandleOp(
1009 OpKernelConstruction* ctx)
1010 : OpKernel(ctx) {
1011 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
1012 OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1013 OP_REQUIRES(
1014 ctx,
1015 output_dtypes_.empty() || output_shapes_.empty() ||
1016 output_dtypes_.size() == output_shapes_.size(),
1017 errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
1018 "are set, they must have the same length."));
1019 }
1020
Compute(OpKernelContext * ctx)1021 void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
1022 const Tensor& string_handle_t = ctx->input(0);
1023 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
1024 errors::InvalidArgument("string_handle must be a scalar"));
1025
1026 ResourceHandle resource_handle;
1027 OP_REQUIRES(
1028 ctx, resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
1029 errors::InvalidArgument(
1030 "Could not parse string_handle as a valid ResourceHandle"));
1031
1032 OP_REQUIRES(
1033 ctx, resource_handle.device() == ctx->device()->attributes().name(),
1034 errors::InvalidArgument("Attempted create an iterator on device \"",
1035 ctx->device()->attributes().name(),
1036 "\" from handle defined on device \"",
1037 resource_handle.device(), "\""));
1038
1039 // Validate that the handle corresponds to a real resource, and
1040 // that it is an IteratorResource.
1041 IteratorResource* iterator_resource;
1042 OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
1043 core::ScopedUnref unref_iterator(iterator_resource);
1044 if (!output_dtypes_.empty()) {
1045 OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
1046 iterator_resource->output_dtypes()));
1047 }
1048 if (!output_shapes_.empty()) {
1049 OP_REQUIRES_OK(ctx,
1050 VerifyShapesCompatible(output_shapes_,
1051 iterator_resource->output_shapes()));
1052 }
1053
1054 Tensor* resource_handle_t;
1055 OP_REQUIRES_OK(ctx,
1056 ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
1057 resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
1058 }
1059
SerializeIteratorOp(OpKernelConstruction * ctx)1060 SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
1061 : OpKernel(ctx) {
1062 if (ctx->HasAttr(kExternalStatePolicy)) {
1063 int64_t state_change_option;
1064 OP_REQUIRES_OK(ctx,
1065 ctx->GetAttr(kExternalStatePolicy, &state_change_option));
1066 external_state_policy_ =
1067 SerializationContext::ExternalStatePolicy(state_change_option);
1068 }
1069 }
1070
Compute(OpKernelContext * ctx)1071 void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
1072 tensorflow::ResourceTagger tag(kTFDataResourceTag,
1073 ctx->op_kernel().type_string());
1074 const Tensor& resource_handle_t = ctx->input(0);
1075 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
1076 errors::InvalidArgument("resource_handle must be a scalar"));
1077 // Validate that the handle corresponds to a real resource, and
1078 // that it is an IteratorResource.
1079 IteratorResource* iterator_resource;
1080 OP_REQUIRES_OK(
1081 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1082 core::ScopedUnref unref_iterator(iterator_resource);
1083 IteratorVariantSerializer serializer;
1084 SerializationContext::Params params(ctx);
1085 params.external_state_policy = external_state_policy_;
1086 SerializationContext serialization_ctx(params);
1087 OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
1088 iterator_resource));
1089 Tensor* serialized_t;
1090 OP_REQUIRES_OK(ctx,
1091 ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
1092 &serialized_t));
1093 OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
1094 }
1095
Compute(OpKernelContext * ctx)1096 void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
1097 tensorflow::ResourceTagger tag(kTFDataResourceTag,
1098 ctx->op_kernel().type_string());
1099 // Validate that the handle corresponds to a real resource, and
1100 // that it is an IteratorResource.
1101 IteratorResource* iterator_resource;
1102 OP_REQUIRES_OK(
1103 ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
1104 core::ScopedUnref unref_iterator(iterator_resource);
1105 const Tensor* serialized_t;
1106 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
1107 IteratorVariantSerializer serializer;
1108 OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
1109 Status s = iterator_resource->Restore(ctx, serializer.GetReader());
1110 if (!s.ok()) {
1111 OP_REQUIRES_OK(
1112 ctx,
1113 errors::CreateWithUpdatedMessage(
1114 s, absl::StrCat(
1115 "Failed to restore dataset iterator from checkpoint: ",
1116 s.error_message(),
1117 ". Make sure the dataset definition has not changed between "
1118 "the process that saved the checkpoint and the process that "
1119 "is restoring it.")));
1120 }
1121 }
1122
namespace {

// Kernel registrations for the iterator-related ops.
//
// Ops registered for both DEVICE_CPU and DEVICE_GPU give the CPU registration
// Priority(2) and the GPU registration Priority(1).
// NOTE(review): presumably the larger priority makes the CPU kernel the
// preferred one when both are eligible — confirm against the kernel
// registration documentation.
// Some GPU registrations additionally name one argument via HostMemory(...).

// Iterator handle creation, initialization and deletion.
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU).Priority(1),
                        IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU).Priority(2),
                        MakeIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeIterator").Device(DEVICE_GPU).Priority(1).HostMemory("dataset"),
    MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_CPU).Priority(2),
                        DeleteIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE_GPU).Priority(1),
                        DeleteIteratorOp);
// All versions of AnonymousIterator share one kernel class.
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIterator").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV2").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV3").Device(DEVICE_CPU).Priority(2),
    AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("AnonymousIteratorV3").Device(DEVICE_GPU).Priority(1),
    AnonymousIteratorHandleOp);
// CPU-only registrations.
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                        ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                        OneShotIteratorOp);
// "IteratorGetNext" and "IteratorGetNextSync" share IteratorGetNextOp; the
// kernel tells them apart by its type string in AsAsync().
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU).Priority(2),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU).Priority(1),
                        IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextSync").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_CPU).Priority(2),
    IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorGetNextAsOptional").Device(DEVICE_GPU).Priority(1),
    IteratorGetNextAsOptionalOp);
// String-handle conversions; the GPU registrations name "string_handle" via
// HostMemory.
REGISTER_KERNEL_BUILDER(
    Name("IteratorToStringHandle").Device(DEVICE_CPU).Priority(2),
    IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("IteratorFromStringHandleV2").Device(DEVICE_CPU).Priority(2),
    IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
                            .Device(DEVICE_GPU)
                            .HostMemory("string_handle")
                            .Priority(1),
                        IteratorFromStringHandleOp);
// Iterator checkpointing; registered for CPU only.
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
                        SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                        DeserializeIteratorOp);

}  // namespace
1201
1202 } // namespace data
1203 } // namespace tensorflow
1204