/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <cstdint>
#include <deque>
#include <functional>
#include <utility>

#include "absl/time/time.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/metric_utils.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/data/unbounded_thread_pool.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace data {
namespace {

const char kAnonymousMultiDeviceIterator[] = "AnonymousMultiDeviceIterator";
const char kAnonymousMultiDeviceIteratorV3[] = "AnonymousMultiDeviceIteratorV3";
const char kDevices[] = "devices";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

struct HostBufferElement {
  Status status;
  bool end_of_sequence;
  std::vector<Tensor> value;
};

using MultiDeviceIteratorCallback =
    std::function<void(const HostBufferElement&)>;

// MultiDeviceIterator provides the ability for multiple devices to fetch from
// one iterator in a round-robin sequence, which is deterministic. This means
// that, for example, starting from the beginning, GetNextFromShard(0) always
// gets the first element and GetNextFromShard(1) always gets the second
// element, even if GetNextFromShard(1) is called before GetNextFromShard(0).
//
// Note on cancellation:
//   * MultiDeviceIterator can be cancelled as a whole by calling Reset() or
//     cancelling MultiDeviceIterator::cancellation_manager().
//   * GetNextFromShard can be cancelled independently. Cancelling
//     GetNextFromShard for one shard doesn't cancel the underlying prefetching,
//     nor does it affect other calls of GetNextFromShard.
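//
// The op kernels defined later in this file wire this resource into the graph:
// MultiDeviceIteratorHandleOp / AnonymousMultiDeviceIteratorOp create the
// resource, MultiDeviceIteratorInit binds it to a dataset and returns an
// incarnation id, and MultiDeviceIteratorGetNextFromShard fetches elements for
// a particular shard, validating the incarnation id it is passed.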
class MultiDeviceIterator : public ResourceBase {
 public:
  MultiDeviceIterator(
      Env* env, const DataTypeVector& output_types,
      const std::vector<PartialTensorShape>& output_shapes,
      const std::vector<string>& devices,
      std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* flr,
      std::unique_ptr<FunctionHandleCache> function_handle_cache)
      : metrics_collector_(flr ? flr->device()->device_type() : DEVICE_DEFAULT,
                           *env),
        unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"),
        output_types_(output_types),
        output_shapes_(output_shapes),
        devices_(devices),
        flib_def_(std::move(flib_def)),
        flr_(flr),
        pflr_(std::move(pflr)),
        function_handle_cache_(std::move(function_handle_cache)) {
    DCHECK(flr_ != nullptr);
    VLOG(2) << "Creating multi-device iterator.";
  }

  ~MultiDeviceIterator() override {
    VLOG(2) << "Destroying multi-device iterator.";
  }

  string DebugString() const override {
    return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
                           " devices");
  }

  Status Init(std::unique_ptr<IteratorBase> iterator, int64_t max_buffer_size,
              int64_t* incarnation_id, DatasetBase* dataset) {
    if (iterator) {
      TF_RETURN_IF_ERROR(
          VerifyTypesMatch(output_types_, iterator->output_dtypes()));
      TF_RETURN_IF_ERROR(
          VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
    }

    mutex_lock l(mu_);
    if (multi_device_buffer_) {
      multi_device_buffer_->Reset();
    }
    dataset->Ref();
    dataset_.reset(dataset);

    ++incarnation_id_;
    *incarnation_id = incarnation_id_;

    multi_device_buffer_ = std::make_unique<MultiDeviceBuffer>(
        devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator),
        this);
    return OkStatus();
  }

  Status GetNextFromShard(OpKernelContext* ctx, int shard_num,
                          int64_t incarnation_id,
                          MultiDeviceIteratorCallback callback) {
    tf_shared_lock l(mu_);
    IteratorContext::Params params(ctx);
    params.flr = flr_;
    params.function_handle_cache = function_handle_cache_.get();
    params.resource_mgr = &resource_mgr_;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
    params.thread_pool = &unbounded_thread_pool_;
    params.cancellation_manager = ctx->cancellation_manager();
    IteratorContext iter_ctx(std::move(params));
    multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
                                           std::move(callback));
    return OkStatus();
  }

  const DataTypeVector& output_types() const { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

  FunctionLibraryRuntime* const flr() {
    tf_shared_lock l(mu_);
    return flr_;
  }
  FunctionHandleCache* function_handle_cache() {
    return function_handle_cache_.get();
  }

  ResourceMgr* resource_mgr() { return &resource_mgr_; }

  CancellationManager* cancellation_manager() { return &cancellation_manager_; }

  IteratorMetricsCollector& metrics_collector() { return metrics_collector_; }

 private:
  // A private class that uses a background thread to keep a per-device buffer
  // full.
  class MultiDeviceBuffer {
   public:
    MultiDeviceBuffer(size_t size, int64_t max_buffer_size,
                      int64_t incarnation_id,
                      std::unique_ptr<IteratorBase> host_iterator,
                      MultiDeviceIterator* parent)
        : buffer_(size),
          size_(size),
          max_buffer_size_(max_buffer_size),
          incarnation_id_(incarnation_id),
          host_iterator_(std::move(host_iterator)),
          parent_(parent) {}

    ~MultiDeviceBuffer() {
      {
        mutex_lock l(mu_);
        if (!background_thread_started_) return;
      }
      Reset();
    }

    void Reset() TF_LOCKS_EXCLUDED(mu_) {
      {
        mutex_lock l(mu_);
        if (background_thread_ && !background_thread_finished_) {
          cancellation_manager_.StartCancel();
          // Wake up the background thread.
          for (int i = 0; i < size_; ++i) {
            buffer_[i].cond_var.notify_all();
          }

          // Make sure background thread has finished first.
          while (!background_thread_finished_) {
            shutdown_cond_var_.wait(l);
          }
        }
      }
      RunPendingCallbacks();
    }

    void GetNextFromShard(IteratorContext* ctx, int shard_num,
                          int64_t incarnation_id,
                          MultiDeviceIteratorCallback callback) {
      HostBufferElement elem;
      if (incarnation_id_ != incarnation_id) {
        elem.status = errors::InvalidArgument(
            "Invalid incarnation id. Provided: ", incarnation_id,
            "; Expected: ", incarnation_id_);
        callback(elem);
        return;
      }

      bool produced_output = false;
      {
        mutex_lock l(mu_);
        if (cancellation_manager_.IsCancelled()) {
          elem.status = errors::Cancelled("Cancelled Multidevice iterator");
          callback(elem);
          return;
        }

        EnsureBackgroundThreadStarted(ctx);

        if (!buffer_[shard_num].data.empty()) {
          produced_output = true;
          std::swap(elem, buffer_[shard_num].data.front());
          buffer_[shard_num].data.pop_front();
          // Wake up background thread if it is blocked on this element.
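          // The background thread waits while this shard's buffer holds
          // max_buffer_size_ elements, so popping one element frees exactly
          // one slot; notify only when the buffer was full before the pop.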
          if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
            buffer_[shard_num].cond_var.notify_all();
          }
        } else {
          if (end_of_iterator_) {
            produced_output = true;
            elem.end_of_sequence = true;
          } else {
            auto callback_container =
                std::make_shared<HostBuffer::CallbackContainer>(
                    std::move(callback));
            elem.status = RegisterCancellationCallback(
                ctx->cancellation_manager(),
                [callback_container]() {
                  if (callback_container->is_called.exchange(true)) {
                    return;
                  }
                  HostBufferElement elem;
                  elem.status =
                      errors::Cancelled("GetNextFromShard was cancelled");
                  callback_container->callback(elem);
                },
                &callback_container->deregister_cancellation);
            if (!elem.status.ok()) {
              callback_container->callback(elem);
              return;
            }
            buffer_[shard_num].callbacks.push_back(
                std::move(callback_container));
            buffer_[shard_num].cond_var.notify_all();
            callback = nullptr;
          }
        }
      }

      if (produced_output) {
        callback(elem);
      }
    }

   private:
    void EnsureBackgroundThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!background_thread_) {
        IteratorContext::Params params(ctx);
        params.cancellation_manager = &cancellation_manager_;
        background_thread_ =
            parent_->unbounded_thread_pool_.get_thread_factory()->StartThread(
                "tf_data_multi_device_iterator",
                std::bind(
                    &MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
                    this,
                    std::make_shared<IteratorContext>(std::move(params))));
      }
    }

    void RunPendingCallbacks() TF_LOCKS_EXCLUDED(mu_) {
      // Run all remaining callbacks.

      std::vector<std::shared_ptr<HostBuffer::CallbackContainer>>
          callback_containers;
      std::vector<HostBufferElement> cancellation_elements;
      {
        mutex_lock l(mu_);

        for (int i = 0; i < size_; ++i) {
          while (!buffer_[i].callbacks.empty()) {
            if (buffer_[i].callbacks.front()->is_called.exchange(true)) {
              buffer_[i].callbacks.pop_front();
              continue;
            }
            if (buffer_[i].data.empty()) {
              HostBufferElement elem;
              if (end_of_iterator_) {
                elem.end_of_sequence = true;
              } else {
                elem.status =
                    errors::Cancelled("Cancelled and buffer not filled.");
              }
              cancellation_elements.push_back(std::move(elem));
            } else {
              cancellation_elements.push_back(
                  std::move(buffer_[i].data.front()));
              buffer_[i].data.pop_front();
            }
            callback_containers.push_back(
                std::move(buffer_[i].callbacks.front()));
            buffer_[i].callbacks.pop_front();
          }
        }
      }
      for (int i = 0; i < callback_containers.size(); ++i) {
        if (callback_containers[i]->deregister_cancellation != nullptr) {
          callback_containers[i]->deregister_cancellation();
        }
        // We invoke the callback regardless of whether deregistration
        // succeeds, because we have already set is_called=true, which
        // effectively disables the cancellation callback.
        callback_containers[i]->callback(cancellation_elements[i]);
      }
    }

    void BackgroundThread(std::shared_ptr<IteratorContext> ctx) {
      {
        mutex_lock l(mu_);
        background_thread_started_ = true;
      }
      int shard_to_fetch = 0;
      while (true) {
        HostBufferElement elem;
        bool end_of_iterator = false;

        {
          mutex_lock l(mu_);
          while (!cancellation_manager_.IsCancelled() &&
                 buffer_[shard_to_fetch].data.size() >= max_buffer_size_ &&
                 buffer_[shard_to_fetch].callbacks.empty()) {
            buffer_[shard_to_fetch].cond_var.wait(l);
          }

          if (cancellation_manager_.IsCancelled()) {
            background_thread_finished_ = true;
            shutdown_cond_var_.notify_all();
            return;
          }
        }

        elem.status = host_iterator_->GetNext(ctx.get(), &elem.value,
                                              &elem.end_of_sequence);

        if (elem.status.ok() && elem.end_of_sequence) {
          end_of_iterator = true;
        }

        std::shared_ptr<HostBuffer::CallbackContainer> callback_container;
        {
          mutex_lock l(mu_);
          // Try to find a callback; otherwise just push the element into the
          // buffer.
          if (!buffer_[shard_to_fetch].callbacks.empty()) {
            while (!buffer_[shard_to_fetch].callbacks.empty()) {
              if (buffer_[shard_to_fetch].callbacks.front()->is_called.exchange(
                      true)) {
                // This callback is already cancelled.
                buffer_[shard_to_fetch].callbacks.pop_front();
                continue;
              } else {
                callback_container =
                    std::move(buffer_[shard_to_fetch].callbacks.front());
                buffer_[shard_to_fetch].callbacks.pop_front();
                break;
              }
            }
          } else {
            buffer_[shard_to_fetch].data.push_back(std::move(elem));
            elem = HostBufferElement();
          }
        }

        if (callback_container) {
          if (callback_container->deregister_cancellation != nullptr) {
            callback_container->deregister_cancellation();
          }
          (*ctx->runner())(std::bind(std::move(callback_container->callback),
                                     std::move(elem)));
        }

        // Finish off the thread if we reach the end of the iterator. Runs
        // pending callbacks.
        if (end_of_iterator) {
          {
            mutex_lock l(mu_);
            background_thread_finished_ = true;
            end_of_iterator_ = true;
            shutdown_cond_var_.notify_all();
          }
          RunPendingCallbacks();
          return;
        }
        shard_to_fetch = (shard_to_fetch + 1) % size_;
      }
    }

    struct HostBuffer {
      condition_variable cond_var;
      std::deque<HostBufferElement> data;

      struct CallbackContainer {
        MultiDeviceIteratorCallback callback;
        // Whether the callback has already been called, either by the
        // background thread or by the cancellation callback.
        std::atomic<bool> is_called;
        std::function<void()> deregister_cancellation;

        explicit CallbackContainer(MultiDeviceIteratorCallback&& callback)
            : callback(std::move(callback)), is_called(false) {}
      };
      // The CallbackContainer is shared with the cancellation callback.
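      // Whichever party wins the is_called exchange invokes the callback, so
      // each pending GetNextFromShard callback runs exactly once.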
      std::deque<std::shared_ptr<CallbackContainer>> callbacks;
    };

    mutex mu_;
    bool background_thread_finished_ TF_GUARDED_BY(mu_) = false;
    bool background_thread_started_ TF_GUARDED_BY(mu_) = false;
    bool end_of_iterator_ TF_GUARDED_BY(mu_) = false;
    condition_variable shutdown_cond_var_ TF_GUARDED_BY(mu_);

    std::vector<HostBuffer> buffer_;

    const size_t size_;
    const int64_t max_buffer_size_;
    const int64_t incarnation_id_;
    CancellationManager cancellation_manager_;
    const std::unique_ptr<IteratorBase> host_iterator_;
    MultiDeviceIterator* const parent_;  // Not owned.
    std::unique_ptr<Thread> background_thread_ TF_GUARDED_BY(mu_);
  };

  IteratorMetricsCollector metrics_collector_;
  UnboundedThreadPool unbounded_thread_pool_;

  mutex mu_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const std::vector<string> devices_;
  const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* const flr_ = nullptr;  // Not owned.
  const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  const std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ResourceMgr resource_mgr_;
  CancellationManager cancellation_manager_;

  int64_t incarnation_id_ TF_GUARDED_BY(mu_) = 0;
  std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ TF_GUARDED_BY(mu_);
  core::RefCountPtr<DatasetBase> dataset_;
};

// Used to generate unique names for anonymous multi-device iterators.
static std::atomic<int64_t> current_id_;

// Just creates a MultiDeviceIterator and returns it.
class MultiDeviceIteratorHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
  }

  // The resource is deleted from the resource manager only when it is private
  // to the kernel.
  ~MultiDeviceIteratorHandleOp() override {
    if (resource_ != nullptr) {
      resource_->Unref();
      if (cinfo_.resource_is_private_to_kernel()) {
        if (!cinfo_.resource_manager()
                 ->template Delete<MultiDeviceIterator>(cinfo_.container(),
                                                        cinfo_.name())
                 .ok()) {
          // Do nothing; the resource may have been deleted by session resets.
        }
      }
    }
  }

  void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_) {
    string unique_name = cinfo_.name();
    string container_name = cinfo_.container();
    {
      mutex_lock l(mu_);
      if (resource_ == nullptr) {
        FunctionLibraryRuntime* flr;
        std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr));
        auto function_handle_cache = std::make_unique<FunctionHandleCache>(flr);
        ResourceMgr* mgr = context->resource_manager();
        OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

        MultiDeviceIterator* resource;

        if (name_ == ResourceHandle::ANONYMOUS_NAME) {
          unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
                                        current_id_.fetch_add(1));
          container_name = kAnonymousMultiDeviceIterator;
          resource = new MultiDeviceIterator(
              context->env(), output_types_, output_shapes_, devices_,
              std::move(flib_def), std::move(pflr), flr,
              std::move(function_handle_cache));
          // NOTE: `mgr->Create()` transfers the one reference on `resource` to
          // `mgr`.
          OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
                                      container_name, unique_name, resource));
        } else {
          unique_name = cinfo_.name();
          container_name = cinfo_.container();
          OP_REQUIRES_OK(context,
                         mgr->LookupOrCreate<MultiDeviceIterator>(
                             container_name, unique_name, &resource,
                             [this, context, flr, &flib_def, &pflr,
                              &function_handle_cache](MultiDeviceIterator** ret)
                                 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                                   *ret = new MultiDeviceIterator(
                                       context->env(), output_types_,
                                       output_shapes_, devices_,
                                       std::move(flib_def), std::move(pflr),
                                       flr, std::move(function_handle_cache));
                                   return OkStatus();
                                 }));
          Status s = VerifyResource(resource);
          if (TF_PREDICT_FALSE(!s.ok())) {
            resource->Unref();
            context->SetStatus(s);
            return;
          }
          resource_ = resource;
        }
      }
    }
    OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                                context, 0, container_name, unique_name,
                                TypeIndex::Make<MultiDeviceIterator>()));
  }

 private:
  // During the first Compute(), the resource is either created or looked up
  // using shared_name. In the latter case, the resource found should be
  // verified to be compatible with this op's configuration. The verification
  // may fail in cases such as two graphs asking queues of the same shared name
  // to have inconsistent capacities.
  Status VerifyResource(MultiDeviceIterator* resource) {
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_types_, resource->output_types()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
    return OkStatus();
  }

  mutex mu_;
  ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
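  // Reference to the resource obtained via shared_name lookup; released (and,
  // if private to the kernel, deleted from the resource manager) in the
  // destructor.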
  MultiDeviceIterator* resource_ TF_GUARDED_BY(mu_) = nullptr;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  const int graph_def_version_;
  string name_;
  string container_;
  std::vector<string> devices_;
};

REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
                        MultiDeviceIteratorHandleOp);

class AnonymousMultiDeviceIteratorOp
    : public AnonymousResourceOp<MultiDeviceIterator> {
 public:
  explicit AnonymousMultiDeviceIteratorOp(OpKernelConstruction* ctx)
      : AnonymousResourceOp<MultiDeviceIterator>(
            ctx,
            /* ref_counting */ true,
            /* Only V1 returns a deleter */
            /* return_deleter */
            ctx->def().op() == kAnonymousMultiDeviceIterator) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 private:
  string name() override { return kAnonymousMultiDeviceIterator; }

  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                        FunctionLibraryRuntime* lib,
                        MultiDeviceIterator** resource) override {
    auto function_handle_cache = std::make_unique<FunctionHandleCache>(lib);
    *resource =
        new MultiDeviceIterator(ctx->env(), output_dtypes_, output_shapes_,
                                devices_, std::move(flib_def), std::move(pflr),
                                lib, std::move(function_handle_cache));
    return OkStatus();
  }

  std::vector<string> devices_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name(kAnonymousMultiDeviceIterator).Device(DEVICE_CPU),
                        AnonymousMultiDeviceIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name(kAnonymousMultiDeviceIteratorV3).Device(DEVICE_CPU),
    AnonymousMultiDeviceIteratorOp);

// Calls init on the MultiDeviceIterator.
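// This verifies the dataset's output types/shapes against the resource,
// resets any existing per-device buffers, and returns a fresh incarnation id
// that subsequent GetNextFromShard calls must present.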
class MultiDeviceIteratorInitOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* tensor_max_buffer_size;
    OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
    int64_t max_buffer_size = tensor_max_buffer_size->scalar<int64_t>()();

    DatasetBase* dataset;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
    core::RefCountPtr<MultiDeviceIterator> resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 1), &resource));

    IteratorContext::Params params(ctx);
    params.flr = resource->flr();
    params.function_handle_cache = resource->function_handle_cache();
    params.resource_mgr = resource->resource_mgr();
    params.cancellation_manager = resource->cancellation_manager();
    std::function<void()> deregister_fn;
    OP_REQUIRES_OK(
        ctx, RegisterCancellationCallback(
                 ctx->cancellation_manager(),
                 [cm = params.cancellation_manager]() { cm->StartCancel(); },
                 &deregister_fn));
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
    IteratorContext iter_ctx(std::move(params));

    std::unique_ptr<IteratorBase> iterator;
    DatasetBase* finalized_dataset;
    OP_REQUIRES_OK(ctx, FinalizeDataset(ctx, dataset, &finalized_dataset));
    OP_REQUIRES_OK(ctx, finalized_dataset->MakeIterator(std::move(iter_ctx),
                                                        /*parent=*/nullptr,
                                                        "Iterator", &iterator));
    core::ScopedUnref unref(finalized_dataset);
    int64_t incarnation_id;
    OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
                                       &incarnation_id, dataset));
    Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
    tensor_incarnation_id.scalar<int64_t>()() = incarnation_id;
    OP_REQUIRES_OK(ctx,
                   ctx->set_output("incarnation_id", tensor_incarnation_id));
  }
};

REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
                        MultiDeviceIteratorInitOp);

// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
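// The work is scheduled on a BackgroundWorker because GetNextFromShard may
// block until the background thread produces an element for the shard; the
// done callback is invoked only after the element callback has run.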
class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
 public:
  explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(),
                           "tf_data_multi_device_iterator_get_next") {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    const Tensor* tensor_shard_num;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
    int32_t shard_num = tensor_shard_num->scalar<int32>()();

    const Tensor* tensor_incarnation_id;
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
    int64_t incarnation_id = tensor_incarnation_id->scalar<int64_t>()();

    MultiDeviceIterator* iterator;
    OP_REQUIRES_OK_ASYNC(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);

    background_worker_.Schedule(std::bind(
        [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
          Notification n;
          absl::Time start_time = iterator->metrics_collector().RecordStart();
          MultiDeviceIteratorCallback callback = std::bind(
              [ctx, iterator, start_time, &n](const HostBufferElement& elem) {
                iterator->metrics_collector().RecordStop(start_time,
                                                         elem.value);
                Status s = elem.status;
                if (!s.ok()) {
                  ctx->SetStatus(s);
                } else if (elem.end_of_sequence) {
                  ctx->SetStatus(errors::OutOfRange("End of sequence"));
                } else {
                  for (int i = 0; i < elem.value.size(); ++i) {
                    ctx->set_output(i, elem.value[i]);
                  }
                }
                n.Notify();
              },
              std::placeholders::_1);

          Status s = iterator->GetNextFromShard(ctx, shard_num, incarnation_id,
                                                std::move(callback));
          if (!s.ok()) {
            ctx->SetStatus(s);
            iterator->Unref();
            done();
            return;
          }
          iterator->Unref();
          n.WaitForNotification();
          done();
        },
        std::move(done)));
  }

 private:
  BackgroundWorker background_worker_;
};

REGISTER_KERNEL_BUILDER(
    Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
    MultiDeviceIteratorGetNextFromShardOp);

class MultiDeviceIteratorToStringHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is a MultiDeviceIterator.
    core::RefCountPtr<MultiDeviceIterator> resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &resource));

    Tensor* string_handle_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &string_handle_t));
    string_handle_t->scalar<tstring>()() =
        resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
    MultiDeviceIteratorToStringHandleOp);

class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
    OP_REQUIRES(
        ctx,
        output_types_.empty() || output_shapes_.empty() ||
            output_types_.size() == output_shapes_.size(),
        errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                                "are set, they must have the same length."));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& string_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
                errors::InvalidArgument("string_handle must be a scalar"));

    ResourceHandle resource_handle;
    OP_REQUIRES(
        ctx,
        resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
        errors::InvalidArgument(
            "Could not parse string_handle as a valid ResourceHandle"));

    OP_REQUIRES(
        ctx, resource_handle.device() == ctx->device()->attributes().name(),
        errors::InvalidArgument("Attempted to create an iterator on device \"",
                                ctx->device()->attributes().name(),
                                "\" from a handle defined on device \"",
                                resource_handle.device(), "\""));

    // Validate that the handle corresponds to a real resource, and
    // that it is a MultiDeviceIterator.
    core::RefCountPtr<MultiDeviceIterator> resource;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
    if (!output_types_.empty()) {
      OP_REQUIRES_OK(ctx,
                     VerifyTypesMatch(output_types_, resource->output_types()));
    }
    if (!output_shapes_.empty()) {
      OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
                                                 resource->output_shapes()));
    }

    Tensor* resource_handle_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
    resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
  }

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(
    Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
    MultiDeviceIteratorFromStringHandleOp);

class DeleteMultiDeviceIteratorOp : public OpKernel {
 public:
  explicit DeleteMultiDeviceIteratorOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
    // The iterator resource is guaranteed to exist because the variant tensor
    // wrapping the deleter is provided as an unused input to this op, which
    // guarantees that it has not run yet.
    OP_REQUIRES_OK(ctx, DeleteResource(ctx, handle));
  }
};

REGISTER_KERNEL_BUILDER(Name("DeleteMultiDeviceIterator").Device(DEVICE_CPU),
                        DeleteMultiDeviceIteratorOp);
// Since this op takes in iterator handles as (unused) inputs, we don't want
// to constrain the iterator location to CPU only. Therefore, we exempt the
// colocation restriction for this op, allowing the iterators to be placed on
// other devices.
REGISTER_INPUT_COLOCATION_EXEMPTION("DeleteMultiDeviceIterator");

}  // namespace
}  // namespace data
}  // namespace tensorflow