/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <cstdint>
#include <deque>
#include <functional>
#include <utility>

#include "absl/time/time.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/metric_utils.h"
#include "tensorflow/core/data/root_dataset.h"
#include "tensorflow/core/data/unbounded_thread_pool.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace data {
namespace {

const char kAnonymousMultiDeviceIterator[] = "AnonymousMultiDeviceIterator";
const char kAnonymousMultiDeviceIteratorV3[] = "AnonymousMultiDeviceIteratorV3";
const char kDevices[] = "devices";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";

struct HostBufferElement {
  Status status;
  bool end_of_sequence = false;
  std::vector<Tensor> value;
};

using MultiDeviceIteratorCallback =
    std::function<void(const HostBufferElement&)>;

// MultiDeviceIterator provides the ability for multiple devices to fetch from
// one iterator in a round-robin sequence, which is deterministic. This means
// that, for example, starting from the beginning, GetNextFromShard(0) always
// gets the first element and GetNextFromShard(1) always gets the second
// element, even if GetNextFromShard(1) is called before GetNextFromShard(0).
//
// Note on cancellation:
//   * MultiDeviceIterator can be cancelled as a whole by calling Reset() or
//   by cancelling MultiDeviceIterator::cancellation_manager().
//   * GetNextFromShard can be cancelled independently. Cancelling
//   GetNextFromShard for one shard doesn't cancel the underlying prefetching,
//   nor does it cancel other calls of GetNextFromShard.
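//
// As a concrete illustration of the round-robin contract (see
// MultiDeviceBuffer::BackgroundThread below): with three shards, source
// element 0 goes to shard 0, element 1 to shard 1, element 2 to shard 2,
// element 3 to shard 0 again, and so on, regardless of the order in which the
// shards issue their GetNextFromShard calls.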
class MultiDeviceIterator : public ResourceBase {
 public:
  MultiDeviceIterator(
      Env* env, const DataTypeVector& output_types,
      const std::vector<PartialTensorShape>& output_shapes,
      const std::vector<string>& devices,
      std::unique_ptr<FunctionLibraryDefinition> flib_def,
      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
      FunctionLibraryRuntime* flr,
      std::unique_ptr<FunctionHandleCache> function_handle_cache)
      : metrics_collector_(flr ? flr->device()->device_type() : DEVICE_DEFAULT,
                           *env),
        unbounded_thread_pool_(env, "tf_data_multi_device_iterator_resource"),
        output_types_(output_types),
        output_shapes_(output_shapes),
        devices_(devices),
        flib_def_(std::move(flib_def)),
        flr_(flr),
        pflr_(std::move(pflr)),
        function_handle_cache_(std::move(function_handle_cache)) {
    DCHECK(flr_ != nullptr);
    VLOG(2) << "Creating multi-device iterator.";
  }

  ~MultiDeviceIterator() override {
    VLOG(2) << "Destroying multi-device iterator.";
  }

  string DebugString() const override {
    return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
                           " devices");
  }

  Status Init(std::unique_ptr<IteratorBase> iterator, int64_t max_buffer_size,
              int64_t* incarnation_id, DatasetBase* dataset) {
    if (iterator) {
      TF_RETURN_IF_ERROR(
          VerifyTypesMatch(output_types_, iterator->output_dtypes()));
      TF_RETURN_IF_ERROR(
          VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
    }

    mutex_lock l(mu_);
    if (multi_device_buffer_) {
      multi_device_buffer_->Reset();
    }
    dataset->Ref();
    dataset_.reset(dataset);

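    // Bumping the incarnation id invalidates any GetNextFromShard calls that
    // are still outstanding against a previous Init(): they are rejected with
    // an InvalidArgument status in MultiDeviceBuffer::GetNextFromShard.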
    ++incarnation_id_;
    *incarnation_id = incarnation_id_;

    multi_device_buffer_ = std::make_unique<MultiDeviceBuffer>(
        devices_.size(), max_buffer_size, incarnation_id_, std::move(iterator),
        this);
    return OkStatus();
  }

  Status GetNextFromShard(OpKernelContext* ctx, int shard_num,
                          int64_t incarnation_id,
                          MultiDeviceIteratorCallback callback) {
    tf_shared_lock l(mu_);
    IteratorContext::Params params(ctx);
    params.flr = flr_;
    params.function_handle_cache = function_handle_cache_.get();
    params.resource_mgr = &resource_mgr_;
    params.thread_factory = unbounded_thread_pool_.get_thread_factory();
    params.thread_pool = &unbounded_thread_pool_;
    params.cancellation_manager = ctx->cancellation_manager();
    IteratorContext iter_ctx(std::move(params));
    multi_device_buffer_->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
                                           std::move(callback));
    return OkStatus();
  }

  const DataTypeVector& output_types() const { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const {
    return output_shapes_;
  }

  FunctionLibraryRuntime* const flr() {
    tf_shared_lock l(mu_);
    return flr_;
  }

  FunctionHandleCache* function_handle_cache() {
    return function_handle_cache_.get();
  }

  ResourceMgr* resource_mgr() { return &resource_mgr_; }

  CancellationManager* cancellation_manager() { return &cancellation_manager_; }

  IteratorMetricsCollector& metrics_collector() { return metrics_collector_; }

 private:
  // A private class that uses a background thread to keep the per-device
  // buffers full.
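  //
  // The background thread assigns elements from the host iterator to the
  // per-shard buffers in round-robin order, blocking once a shard's buffer
  // holds max_buffer_size elements. GetNextFromShard either pops a buffered
  // element or, if the shard's buffer is empty, registers a callback that is
  // invoked later by the background thread (or by Reset / cancellation).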
  class MultiDeviceBuffer {
   public:
    MultiDeviceBuffer(size_t size, int64_t max_buffer_size,
                      int64_t incarnation_id,
                      std::unique_ptr<IteratorBase> host_iterator,
                      MultiDeviceIterator* parent)
        : buffer_(size),
          size_(size),
          max_buffer_size_(max_buffer_size),
          incarnation_id_(incarnation_id),
          host_iterator_(std::move(host_iterator)),
          parent_(parent) {}

    ~MultiDeviceBuffer() {
      {
        mutex_lock l(mu_);
        if (!background_thread_started_) return;
      }
      Reset();
    }

    void Reset() TF_LOCKS_EXCLUDED(mu_) {
      {
        mutex_lock l(mu_);
        if (background_thread_ && !background_thread_finished_) {
          cancellation_manager_.StartCancel();
          // Wake up the background thread.
          for (int i = 0; i < size_; ++i) {
            buffer_[i].cond_var.notify_all();
          }

          // Make sure the background thread has finished first.
          while (!background_thread_finished_) {
            shutdown_cond_var_.wait(l);
          }
        }
      }
      RunPendingCallbacks();
    }

    void GetNextFromShard(IteratorContext* ctx, int shard_num,
                          int64_t incarnation_id,
                          MultiDeviceIteratorCallback callback) {
      HostBufferElement elem;
      if (incarnation_id_ != incarnation_id) {
        elem.status = errors::InvalidArgument(
            "Invalid incarnation id. Provided: ", incarnation_id,
            "; Expected: ", incarnation_id_);
        callback(elem);
        return;
      }

      bool produced_output = false;
      {
        mutex_lock l(mu_);
        if (cancellation_manager_.IsCancelled()) {
          elem.status = errors::Cancelled("Cancelled Multidevice iterator");
          callback(elem);
          return;
        }

        EnsureBackgroundThreadStarted(ctx);

        if (!buffer_[shard_num].data.empty()) {
          produced_output = true;
          std::swap(elem, buffer_[shard_num].data.front());
          buffer_[shard_num].data.pop_front();
          // Wake up the background thread if it is blocked on this buffer.
          if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
            buffer_[shard_num].cond_var.notify_all();
          }
        } else {
          if (end_of_iterator_) {
            produced_output = true;
            elem.end_of_sequence = true;
          } else {
            auto callback_container =
                std::make_shared<HostBuffer::CallbackContainer>(
                    std::move(callback));
            elem.status = RegisterCancellationCallback(
                ctx->cancellation_manager(),
                [callback_container]() {
                  if (callback_container->is_called.exchange(true)) {
                    return;
                  }
                  HostBufferElement elem;
                  elem.status =
                      errors::Cancelled("GetNextFromShard was cancelled");
                  callback_container->callback(elem);
                },
                &callback_container->deregister_cancellation);
            if (!elem.status.ok()) {
              callback_container->callback(elem);
              return;
            }
            buffer_[shard_num].callbacks.push_back(
                std::move(callback_container));
            buffer_[shard_num].cond_var.notify_all();
            callback = nullptr;
          }
        }
      }

      if (produced_output) {
        callback(elem);
      }
    }

   private:
    void EnsureBackgroundThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!background_thread_) {
        IteratorContext::Params params(ctx);
        params.cancellation_manager = &cancellation_manager_;
        background_thread_ =
            parent_->unbounded_thread_pool_.get_thread_factory()->StartThread(
                "tf_data_multi_device_iterator",
                std::bind(
                    &MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
                    this,
                    std::make_shared<IteratorContext>(std::move(params))));
      }
    }

    void RunPendingCallbacks() TF_LOCKS_EXCLUDED(mu_) {
      // Run all remaining callbacks.

      std::vector<std::shared_ptr<HostBuffer::CallbackContainer>>
          callback_containers;
      std::vector<HostBufferElement> cancellation_elements;
      {
        mutex_lock l(mu_);

        for (int i = 0; i < size_; ++i) {
          while (!buffer_[i].callbacks.empty()) {
            if (buffer_[i].callbacks.front()->is_called.exchange(true)) {
              buffer_[i].callbacks.pop_front();
              continue;
            }
            if (buffer_[i].data.empty()) {
              HostBufferElement elem;
              if (end_of_iterator_) {
                elem.end_of_sequence = true;
              } else {
                elem.status =
                    errors::Cancelled("Cancelled and buffer not filled.");
              }
              cancellation_elements.push_back(std::move(elem));
            } else {
              cancellation_elements.push_back(
                  std::move(buffer_[i].data.front()));
              buffer_[i].data.pop_front();
            }
            callback_containers.push_back(
                std::move(buffer_[i].callbacks.front()));
            buffer_[i].callbacks.pop_front();
          }
        }
      }
      for (int i = 0; i < callback_containers.size(); ++i) {
        if (callback_containers[i]->deregister_cancellation != nullptr) {
          callback_containers[i]->deregister_cancellation();
        }
        // We invoke the callback regardless of whether deregistration
        // succeeds or not, because we have already set is_called=true, which
        // effectively disables the cancellation callback.
        callback_containers[i]->callback(cancellation_elements[i]);
      }
    }

    void BackgroundThread(std::shared_ptr<IteratorContext> ctx) {
      {
        mutex_lock l(mu_);
        background_thread_started_ = true;
      }
      int shard_to_fetch = 0;
      while (true) {
        HostBufferElement elem;
        bool end_of_iterator = false;

        {
          mutex_lock l(mu_);
          while (!cancellation_manager_.IsCancelled() &&
                 buffer_[shard_to_fetch].data.size() >= max_buffer_size_ &&
                 buffer_[shard_to_fetch].callbacks.empty()) {
            buffer_[shard_to_fetch].cond_var.wait(l);
          }

          if (cancellation_manager_.IsCancelled()) {
            background_thread_finished_ = true;
            shutdown_cond_var_.notify_all();
            return;
          }
        }

        elem.status = host_iterator_->GetNext(ctx.get(), &elem.value,
                                              &elem.end_of_sequence);

        if (elem.status.ok() && elem.end_of_sequence) {
          end_of_iterator = true;
        }

        std::shared_ptr<HostBuffer::CallbackContainer> callback_container;
        {
          mutex_lock l(mu_);
          // Hand the element to a waiting callback if there is one; otherwise
          // push it into the shard's buffer.
          if (!buffer_[shard_to_fetch].callbacks.empty()) {
            while (!buffer_[shard_to_fetch].callbacks.empty()) {
              if (buffer_[shard_to_fetch].callbacks.front()->is_called.exchange(
                      true)) {
                // This callback is already cancelled.
                buffer_[shard_to_fetch].callbacks.pop_front();
                continue;
              } else {
                callback_container =
                    std::move(buffer_[shard_to_fetch].callbacks.front());
                buffer_[shard_to_fetch].callbacks.pop_front();
                break;
              }
            }
          } else {
            buffer_[shard_to_fetch].data.push_back(std::move(elem));
            elem = HostBufferElement();
          }
        }

        if (callback_container) {
          if (callback_container->deregister_cancellation != nullptr) {
            callback_container->deregister_cancellation();
          }
          (*ctx->runner())(std::bind(std::move(callback_container->callback),
                                     std::move(elem)));
        }

        // Finish off the thread if we reach the end of the iterator. Runs
        // pending callbacks.
        if (end_of_iterator) {
          {
            mutex_lock l(mu_);
            background_thread_finished_ = true;
            end_of_iterator_ = true;
            shutdown_cond_var_.notify_all();
          }
          RunPendingCallbacks();
          return;
        }
        shard_to_fetch = (shard_to_fetch + 1) % size_;
      }
    }

    struct HostBuffer {
      condition_variable cond_var;
      std::deque<HostBufferElement> data;
      struct CallbackContainer {
        MultiDeviceIteratorCallback callback;
        // Whether the callback has already been called, either by the
        // background thread or by the cancellation callback.
        std::atomic<bool> is_called;
        std::function<void()> deregister_cancellation;
        explicit CallbackContainer(MultiDeviceIteratorCallback&& callback)
            : callback(std::move(callback)), is_called(false) {}
      };
      // The CallbackContainer is shared with the cancellation callback.
      std::deque<std::shared_ptr<CallbackContainer>> callbacks;
    };

    mutex mu_;
    bool background_thread_finished_ TF_GUARDED_BY(mu_) = false;
    bool background_thread_started_ TF_GUARDED_BY(mu_) = false;
    bool end_of_iterator_ TF_GUARDED_BY(mu_) = false;
    condition_variable shutdown_cond_var_ TF_GUARDED_BY(mu_);

    std::vector<HostBuffer> buffer_;

    const size_t size_;
    const int64_t max_buffer_size_;
    const int64_t incarnation_id_;
    CancellationManager cancellation_manager_;
    const std::unique_ptr<IteratorBase> host_iterator_;
    MultiDeviceIterator* const parent_;  // Not owned.
    std::unique_ptr<Thread> background_thread_ TF_GUARDED_BY(mu_);
  };

  IteratorMetricsCollector metrics_collector_;
  UnboundedThreadPool unbounded_thread_pool_;

  mutex mu_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const std::vector<string> devices_;
  const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* const flr_ = nullptr;  // not owned.
  const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  const std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  ResourceMgr resource_mgr_;
  CancellationManager cancellation_manager_;

  int64_t incarnation_id_ TF_GUARDED_BY(mu_) = 0;
  std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ TF_GUARDED_BY(mu_);
  core::RefCountPtr<DatasetBase> dataset_;
};

// Used to generate unique names for anonymous multi device iterators.
static std::atomic<int64_t> current_id_;

// Just creates a MultiDeviceIterator and returns it.
class MultiDeviceIteratorHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
  }

  // The resource is deleted from the resource manager only when it is private
  // to the kernel.
  ~MultiDeviceIteratorHandleOp() override {
    if (resource_ != nullptr) {
      resource_->Unref();
      if (cinfo_.resource_is_private_to_kernel()) {
        if (!cinfo_.resource_manager()
                 ->template Delete<MultiDeviceIterator>(cinfo_.container(),
                                                        cinfo_.name())
                 .ok()) {
          // Do nothing; the resource may have been deleted by session resets.
        }
      }
    }
  }

  void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_) {
    string unique_name = cinfo_.name();
    string container_name = cinfo_.container();
    {
      mutex_lock l(mu_);
      if (resource_ == nullptr) {
        FunctionLibraryRuntime* flr;
        std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
        OP_REQUIRES_OK(context, context->function_library()->Clone(
                                    &flib_def, &pflr, &flr));
        auto function_handle_cache = std::make_unique<FunctionHandleCache>(flr);
        ResourceMgr* mgr = context->resource_manager();
        OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));

        MultiDeviceIterator* resource;

        if (name_ == ResourceHandle::ANONYMOUS_NAME) {
          unique_name = strings::StrCat("_AnonymousMultiDeviceIterator",
                                        current_id_.fetch_add(1));
          container_name = kAnonymousMultiDeviceIterator;
          resource = new MultiDeviceIterator(
              context->env(), output_types_, output_shapes_, devices_,
              std::move(flib_def), std::move(pflr), flr,
              std::move(function_handle_cache));
          // NOTE: `mgr->Create()` transfers the one reference on `resource` to
          // `mgr`.
          OP_REQUIRES_OK(context, mgr->Create<MultiDeviceIterator>(
                                      container_name, unique_name, resource));
        } else {
          unique_name = cinfo_.name();
          container_name = cinfo_.container();
          OP_REQUIRES_OK(context,
                         mgr->LookupOrCreate<MultiDeviceIterator>(
                             container_name, unique_name, &resource,
                             [this, context, flr, &flib_def, &pflr,
                              &function_handle_cache](MultiDeviceIterator** ret)
                                 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
                                   *ret = new MultiDeviceIterator(
                                       context->env(), output_types_,
                                       output_shapes_, devices_,
                                       std::move(flib_def), std::move(pflr),
                                       flr, std::move(function_handle_cache));
                                   return OkStatus();
                                 }));
          Status s = VerifyResource(resource);
          if (TF_PREDICT_FALSE(!s.ok())) {
            resource->Unref();
            context->SetStatus(s);
            return;
          }
          resource_ = resource;
        }
      }
    }
    OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
                                context, 0, container_name, unique_name,
                                TypeIndex::Make<MultiDeviceIterator>()));
  }

 private:
  // During the first Compute(), the resource is either created or looked up
  // using shared_name. In the latter case, the resource found should be
  // verified to be compatible with this op's configuration. The verification
  // may fail, for example, when two graphs ask for resources with the same
  // shared name but inconsistent output types or shapes.
  Status VerifyResource(MultiDeviceIterator* resource) {
    TF_RETURN_IF_ERROR(
        VerifyTypesMatch(output_types_, resource->output_types()));
    TF_RETURN_IF_ERROR(
        VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
    return OkStatus();
  }

  mutex mu_;
  ContainerInfo cinfo_;  // Written once under mu_ then constant afterwards.
  MultiDeviceIterator* resource_ TF_GUARDED_BY(mu_) = nullptr;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  const int graph_def_version_;
  string name_;
  string container_;
  std::vector<string> devices_;
};

REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
                        MultiDeviceIteratorHandleOp);

class AnonymousMultiDeviceIteratorOp
    : public AnonymousResourceOp<MultiDeviceIterator> {
 public:
  explicit AnonymousMultiDeviceIteratorOp(OpKernelConstruction* ctx)
      : AnonymousResourceOp<MultiDeviceIterator>(
            ctx,
            /* ref_counting */ true,
            /* Only V1 returns a deleter */
            /* return_deleter */
            ctx->def().op() == kAnonymousMultiDeviceIterator) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDevices, &devices_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  }

 private:
  string name() override { return kAnonymousMultiDeviceIterator; }

  Status CreateResource(OpKernelContext* ctx,
                        std::unique_ptr<FunctionLibraryDefinition> flib_def,
                        std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
                        FunctionLibraryRuntime* lib,
                        MultiDeviceIterator** resource) override {
    auto function_handle_cache = std::make_unique<FunctionHandleCache>(lib);
    *resource =
        new MultiDeviceIterator(ctx->env(), output_dtypes_, output_shapes_,
                                devices_, std::move(flib_def), std::move(pflr),
                                lib, std::move(function_handle_cache));
    return OkStatus();
  }

  std::vector<string> devices_;
  DataTypeVector output_dtypes_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name(kAnonymousMultiDeviceIterator).Device(DEVICE_CPU),
                        AnonymousMultiDeviceIteratorOp);
REGISTER_KERNEL_BUILDER(
    Name(kAnonymousMultiDeviceIteratorV3).Device(DEVICE_CPU),
    AnonymousMultiDeviceIteratorOp);

// Calls Init() on the MultiDeviceIterator.
class MultiDeviceIteratorInitOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor* tensor_max_buffer_size;
    OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
    int64_t max_buffer_size = tensor_max_buffer_size->scalar<int64_t>()();

    DatasetBase* dataset;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
    core::RefCountPtr<MultiDeviceIterator> resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 1), &resource));

    IteratorContext::Params params(ctx);
    params.flr = resource->flr();
    params.function_handle_cache = resource->function_handle_cache();
    params.resource_mgr = resource->resource_mgr();
    params.cancellation_manager = resource->cancellation_manager();
    std::function<void()> deregister_fn;
    OP_REQUIRES_OK(
        ctx, RegisterCancellationCallback(
                 ctx->cancellation_manager(),
                 [cm = params.cancellation_manager]() { cm->StartCancel(); },
                 &deregister_fn));
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
    IteratorContext iter_ctx(std::move(params));

    std::unique_ptr<IteratorBase> iterator;
    DatasetBase* finalized_dataset;
    OP_REQUIRES_OK(ctx, FinalizeDataset(ctx, dataset, &finalized_dataset));
    OP_REQUIRES_OK(ctx, finalized_dataset->MakeIterator(std::move(iter_ctx),
                                                        /*parent=*/nullptr,
                                                        "Iterator", &iterator));
    core::ScopedUnref unref(finalized_dataset);
    int64_t incarnation_id;
    OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
                                       &incarnation_id, dataset));
    Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
    tensor_incarnation_id.scalar<int64_t>()() = incarnation_id;
    OP_REQUIRES_OK(ctx,
                   ctx->set_output("incarnation_id", tensor_incarnation_id));
  }
};

REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
                        MultiDeviceIteratorInitOp);

// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
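// The work is scheduled on a background worker thread; the asynchronous
// callback from MultiDeviceIterator is bridged back to this kernel's
// completion by signalling a Notification that the scheduled closure waits on.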
class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
 public:
  explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        background_worker_(ctx->env(),
                           "tf_data_multi_device_iterator_get_next") {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    const Tensor* tensor_shard_num;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
    int32_t shard_num = tensor_shard_num->scalar<int32>()();

    const Tensor* tensor_incarnation_id;
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
    int64_t incarnation_id = tensor_incarnation_id->scalar<int64_t>()();

    MultiDeviceIterator* iterator;
    OP_REQUIRES_OK_ASYNC(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);

    background_worker_.Schedule(std::bind(
        [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
          Notification n;
          absl::Time start_time = iterator->metrics_collector().RecordStart();
          MultiDeviceIteratorCallback callback = std::bind(
              [ctx, iterator, start_time, &n](const HostBufferElement& elem) {
                iterator->metrics_collector().RecordStop(start_time,
                                                         elem.value);
                Status s = elem.status;
                if (!s.ok()) {
                  ctx->SetStatus(s);
                } else if (elem.end_of_sequence) {
                  ctx->SetStatus(errors::OutOfRange("End of sequence"));
                } else {
                  for (int i = 0; i < elem.value.size(); ++i) {
                    ctx->set_output(i, elem.value[i]);
                  }
                }
                n.Notify();
              },
              std::placeholders::_1);

          Status s = iterator->GetNextFromShard(ctx, shard_num, incarnation_id,
                                                std::move(callback));
          if (!s.ok()) {
            ctx->SetStatus(s);
            iterator->Unref();
            done();
            return;
          }
          iterator->Unref();
          n.WaitForNotification();
          done();
        },
        std::move(done)));
  }

 private:
  BackgroundWorker background_worker_;
};

REGISTER_KERNEL_BUILDER(
    Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
    MultiDeviceIteratorGetNextFromShardOp);

class MultiDeviceIteratorToStringHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is a MultiDeviceIterator.
    core::RefCountPtr<MultiDeviceIterator> resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &resource));

    Tensor* string_handle_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &string_handle_t));
    string_handle_t->scalar<tstring>()() =
        resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
  }
};

REGISTER_KERNEL_BUILDER(
    Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
    MultiDeviceIteratorToStringHandleOp);

class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
 public:
  explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
    OP_REQUIRES(
        ctx,
        output_types_.empty() || output_shapes_.empty() ||
            output_types_.size() == output_shapes_.size(),
        errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
                                "are set, they must have the same length."));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& string_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
                errors::InvalidArgument("string_handle must be a scalar"));

    ResourceHandle resource_handle;
    OP_REQUIRES(
        ctx,
        resource_handle.ParseFromString(string_handle_t.scalar<tstring>()()),
        errors::InvalidArgument(
            "Could not parse string_handle as a valid ResourceHandle"));

    OP_REQUIRES(
        ctx, resource_handle.device() == ctx->device()->attributes().name(),
        errors::InvalidArgument("Attempted to create an iterator on device \"",
                                ctx->device()->attributes().name(),
                                "\" from handle defined on device \"",
                                resource_handle.device(), "\""));

    // Validate that the handle corresponds to a real resource, and
    // that it is a MultiDeviceIterator.
    core::RefCountPtr<MultiDeviceIterator> resource;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
    if (!output_types_.empty()) {
      OP_REQUIRES_OK(ctx,
                     VerifyTypesMatch(output_types_, resource->output_types()));
    }
    if (!output_shapes_.empty()) {
      OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
                                                 resource->output_shapes()));
    }

    Tensor* resource_handle_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
    resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
  }

 private:
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(
    Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
    MultiDeviceIteratorFromStringHandleOp);

class DeleteMultiDeviceIteratorOp : public OpKernel {
 public:
  explicit DeleteMultiDeviceIteratorOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    ResourceHandle handle = ctx->input(0).flat<ResourceHandle>()(0);
    // The iterator resource is guaranteed to exist because the variant tensor
    // wrapping the deleter is provided as an unused input to this op, which
    // guarantees that it has not run yet.
    OP_REQUIRES_OK(ctx, DeleteResource(ctx, handle));
  }
};

REGISTER_KERNEL_BUILDER(Name("DeleteMultiDeviceIterator").Device(DEVICE_CPU),
                        DeleteMultiDeviceIteratorOp);
// Since this op takes in Iterator handles as (unused) inputs, we don't want
// to constrain the iterator location to CPU only. Therefore, we exempt this
// op from the input colocation restriction, allowing the iterators to be
// placed on other devices.
REGISTER_INPUT_COLOCATION_EXEMPTION("DeleteMultiDeviceIterator");

}  // namespace
}  // namespace data
}  // namespace tensorflow