1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
17 
18 #include "google/protobuf/any.pb.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/dma_helper.h"
21 #include "tensorflow/core/common_runtime/process_util.h"
22 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
23 #include "tensorflow/core/distributed_runtime/test_utils.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/cancellation.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/lib/core/notification.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/lib/random/random.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/mem.h"
34 #include "tensorflow/core/platform/test.h"
35 #include "tensorflow/core/protobuf/transport_options.pb.h"
36 #include "tensorflow/core/protobuf/worker.pb.h"
37 
38 // The only interesting method on CollectiveRemoteAccessDistributed
39 // that's not on CollectiveRemoteAccessLocal is RecvFromPeer which
40 // issues a RecvBufAsync call against a WorkerInterface.  That's all
41 // that's tested here.  Note that RecvFromPeer can do a
42 // DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
43 // for the RecvBufAsync.
44 
45 namespace tensorflow {
46 namespace {
47 
48 class FakeAllocator : public Allocator {
49  public:
Name()50   string Name() override { return "fake"; }
AllocateRaw(size_t alignment,size_t num_bytes)51   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
52     return port::AlignedMalloc(num_bytes, alignment);
53   }
DeallocateRaw(void * ptr)54   void DeallocateRaw(void* ptr) override { return port::AlignedFree(ptr); }
55 };
56 
NewDevice(const string & type,const string & name,Allocator * allocator)57 static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
58                                          Allocator* allocator) {
59   class FakeDevice : public Device {
60    public:
61     explicit FakeDevice(const DeviceAttributes& attr, Allocator* allocator)
62         : Device(nullptr, attr), allocator_(allocator) {}
63     Status Sync() override { return OkStatus(); }
64     Allocator* GetAllocator(AllocatorAttributes) override { return allocator_; }
65 
66    private:
67     Allocator* const allocator_;
68   };
69   DeviceAttributes attr;
70   attr.set_name(name);
71   attr.set_device_type(type);
72   attr.mutable_locality()->set_numa_node(3);  // a non-default value
73   attr.set_incarnation(random::New64());
74   return std::make_unique<FakeDevice>(attr, allocator);
75 }
76 
// Step id used for every BufRendezvous created by the fake workers and for
// the CollectiveRemoteAccessDistributed under test.  Declared constexpr so
// the shared id cannot be modified by accident.
static constexpr int64_t kStepId = 123;
78 
79 class FakeWorker : public TestWorkerInterface {
80  public:
FakeWorker(const string & name,DeviceMgr * dev_mgr,DeviceResolverDistributed * dres,bool is_failed,bool set_tensor_in_extra)81   FakeWorker(const string& name, DeviceMgr* dev_mgr,
82              DeviceResolverDistributed* dres, bool is_failed,
83              bool set_tensor_in_extra)
84       : name_(name),
85         device_mgr_(dev_mgr),
86         device_resolver_(dres),
87         buf_rendezvous_(kStepId, dev_mgr),
88         is_failed_(is_failed),
89         set_tensor_in_extra_(set_tensor_in_extra) {}
90 
91   // Direct access to a BufRendezvous that holds whatever the remote
92   // worker is supposed to have.
buf_rendezvous()93   BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }
94 
GetStatusAsync(CallOptions * opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)95   void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
96                       GetStatusResponse* response, bool fail_fast,
97                       StatusCallback done) override {
98     if (is_failed_) {
99       done(errors::Unavailable("peer down"));
100       return;
101     }
102     std::vector<DeviceAttributes> dev_attr;
103     device_mgr_->ListDeviceAttributes(&dev_attr);
104     for (const auto& da : dev_attr) {
105       *response->add_device_attributes() = da;
106     }
107     done(OkStatus());
108   }
109 
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)110   void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
111                     RecvBufResponse* response, StatusCallback done) override {
112     if (is_failed_) {
113       done(errors::Unavailable("peer down"));
114       return;
115     }
116     opts->SetCancelCallback([this]() {
117       // Within this test the call is satisfied by a process-local
118       // BufRendezvous table. In real application the BufRendezvous
119       // would be on the other side of a network hop, so call
120       // BufRendezvous::StartAbort() from a separate thread to be
121       // more consistent with that situation and avoid mutex deadlock.
122       SchedClosure([this]() {
123         Env::Default()->SleepForMicroseconds(100);
124         buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
125       });
126     });
127     VLOG(2) << "ConsumeBuf key=" << request->buf_rendezvous_key()
128             << " src_device=" << request->src_device()
129             << " src_incarnation=" << request->src_incarnation();
130     buf_rendezvous_.ConsumeBuf(
131         request->buf_rendezvous_key(), request->src_device(),
132         request->src_incarnation(),
133         [this, opts, request, response, done](const Status& status,
134                                               BufRendezvous::Hook* h) {
135           Status s = status;
136           if (s.ok()) {
137             opts->ClearCancelCallback();
138             int64_t num_bytes = h->prod_value->TotalBytes();
139 
140             if (set_tensor_in_extra_) {
141               // Since this is not really RDMA into pre-allocated memory send
142               // the bytes in the response.
143               RecvBufRespExtra extra;
144               extra.add_tensor_content(string(
145                   reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
146                   num_bytes));
147               response->mutable_transport_options()->PackFrom(extra);
148             } else {
149               if (request->num_bytes() != num_bytes) {
150                 s = errors::Internal("Tensor Size Mismatch.");
151               } else {
152                 memcpy(reinterpret_cast<void*>(request->buf_ptr()),
153                        DMAHelper::base(h->prod_value), num_bytes);
154               }
155             }
156           }
157           done(s);
158           if (h) BufRendezvous::DoneWithHook(h);
159         },
160         nullptr /*cancellation_manager*/);
161   }
162 
163  private:
164   string name_;
165   DeviceMgr* device_mgr_;
166   DeviceResolverDistributed* device_resolver_;
167   BufRendezvous buf_rendezvous_;
168   bool is_failed_;
169   const bool set_tensor_in_extra_;
170 };
171 
172 class FakeCache : public TestWorkerCache {
173  public:
174   // Override the Locality methods to actually pass through to the
175   // worker.
GetDeviceLocalityNonBlocking(const string & device,DeviceLocality * locality)176   bool GetDeviceLocalityNonBlocking(const string& device,
177                                     DeviceLocality* locality) override {
178     return false;
179   }
180 
GetDeviceLocalityAsync(const string & device,DeviceLocality * locality,StatusCallback done)181   void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
182                               StatusCallback done) override {
183     string task_name;
184     string dev_part;
185     if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
186       done(errors::Internal("failed to parse device name"));
187       return;
188     }
189     auto it = workers_.find(task_name);
190     if (it == workers_.end()) {
191       done(errors::Internal("failed to find worker ", task_name));
192       return;
193     }
194     WorkerInterface* wi = it->second;
195     GetStatusRequest req;
196     GetStatusResponse resp;
197     Status status = wi->GetStatus(&req, &resp);
198     if (!status.ok()) {
199       done(status);
200       return;
201     }
202     for (const auto& it : resp.device_attributes()) {
203       if (it.name() == device) {
204         *locality = it.locality();
205         done(OkStatus());
206         return;
207       }
208     }
209     done(errors::Internal("device not found: ", device));
210   }
211 };
212 
// Test parameter: which kind of destination device is simulated.  For the GPU
// value the fixture attaches accelerator device info to the destination
// device before receiving.
enum TEST_PARAM_DEVICE_TYPE {
  TEST_PARAM_DEVICE_TYPE_CPU = 0,
  TEST_PARAM_DEVICE_TYPE_GPU = 1,
};
217 
// Test parameter: how the fake worker returns the tensor bytes.  AT_BUF_PTR
// copies them to the address given in the request; IN_EXTRA packs them into
// the response's transport options.
enum TEST_PARAM_TENSOR_LOC {
  TEST_PARAM_TENSOR_LOC_AT_BUF_PTR = 0,
  TEST_PARAM_TENSOR_LOC_IN_EXTRA = 1,
};
222 
223 class CollRMADistTest
224     : public ::testing::TestWithParam<
225           std::tuple<TEST_PARAM_DEVICE_TYPE, TEST_PARAM_TENSOR_LOC>> {
226  protected:
CollRMADistTest()227   CollRMADistTest()
228       : work_queue_(
229             std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}
230 
~CollRMADistTest()231   ~CollRMADistTest() override {
232     for (DeviceMgr* dm : device_mgrs_) {
233       delete dm;
234     }
235     for (auto it : dev_resolvers_) {
236       delete it.second;
237     }
238     for (FakeWorker* w : workers_) {
239       delete w;
240     }
241   }
242 
SetUp()243   void SetUp() override {
244     const int num_workers = 2;
245     const int num_devices = 1;
246     string device_type = "CPU";
247     string dev0_worker_name;
248     for (int w = 0; w < num_workers; ++w) {
249       string name = strings::StrCat("/job:worker/replica:0/task:", w);
250       if (w == 0) {
251         dev0_worker_name = name;
252       }
253       DefineWorker(name, device_type, num_devices);
254     }
255     // All tests simulate requests from worker 0 to worker 1.
256     rma_.reset(new CollectiveRemoteAccessDistributed(
257         device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
258         kStepId, "/job:worker/replica:0/task:0"));
259 
260     const int kNumElts = 8;
261     expected_value_ = Tensor(DT_FLOAT, {kNumElts});
262     to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
263     large_response_ = Tensor(DT_FLOAT, {2 * kNumElts});
264     auto exp_alias = expected_value_.flat<float>();
265     auto to_alias = to_tensor_.flat<float>();
266     auto large_response_alias = large_response_.flat<float>();
267     for (int i = 0; i < kNumElts; ++i) {
268       exp_alias(i) = i;
269       to_alias(i) = -1;
270     }
271     for (int i = 0; i < 2 * kNumElts; ++i) {
272       large_response_alias(i) = -2;
273     }
274   }
275 
276   // Populates all device resolvers with device attributes of the cluster. This
277   // should be called in the beginning of all tests unless you would like to
278   // simulate a situation that is before parameter resolution.
ResolveDeviceAttributes()279   void ResolveDeviceAttributes() {
280     for (auto& dev_resolver_item : dev_resolvers_) {
281       DeviceResolverDistributed* dev_resolver = dev_resolver_item.second;
282       for (const auto& item : dev_by_task_) {
283         TF_CHECK_OK(dev_resolver->UpdateDeviceAttributes(item.second));
284       }
285     }
286   }
287 
DefineWorker(const string & worker_name,const string & device_type,int num_devices,bool is_failed=false)288   void DefineWorker(const string& worker_name, const string& device_type,
289                     int num_devices, bool is_failed = false) {
290     std::vector<std::unique_ptr<Device>> devices;
291     for (int i = 0; i < num_devices; ++i) {
292       devices.push_back(NewDevice(
293           device_type,
294           strings::StrCat(worker_name, "/device:", device_type, ":", i),
295           &fake_allocator_));
296     }
297     DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
298     device_mgrs_.push_back(dev_mgr);
299     std::vector<DeviceAttributes>* dv = &dev_by_task_[worker_name];
300     dv->clear();
301     for (auto d : dev_mgr->ListDevices()) {
302       dv->push_back(d->attributes());
303     }
304     DeviceResolverDistributed* dev_res = new DeviceResolverDistributed(dev_mgr);
305     dev_resolvers_[worker_name] = dev_res;
306     FakeWorker* fw =
307         new FakeWorker(worker_name, dev_mgr, dev_res, is_failed,
308                        /*set_tensor_in_extra=*/
309                        std::get<TEST_PARAM_TENSOR_LOC>(GetParam()) ==
310                            TEST_PARAM_TENSOR_LOC_IN_EXTRA);
311 
312     workers_.push_back(fw);
313     wc_.AddWorker(worker_name, fw);
314   }
315 
RestartWorker(const string & worker_name,const string & device_type,int num_devices,bool is_failed=false)316   void RestartWorker(const string& worker_name, const string& device_type,
317                      int num_devices, bool is_failed = false) {
318     auto it = dev_resolvers_.find(worker_name);
319     if (it != dev_resolvers_.end()) {
320       delete it->second;
321       dev_resolvers_.erase(it);
322     }
323     // After restarting a worker, the other workers already have the device
324     // attributes of the old worker. We don't broadcast device attributes of the
325     // new worker to mimic the real world.
326     DefineWorker(worker_name, device_type, num_devices, is_failed);
327   }
328 
ValidateResultTensor()329   void ValidateResultTensor() {
330     ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
331     for (int i = 0; i < to_tensor_.NumElements(); ++i) {
332       EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
333                       to_tensor_.flat<float>()(i));
334     }
335   }
336 
ValidateResultTensorUnchanged()337   void ValidateResultTensorUnchanged() {
338     for (int i = 0; i < to_tensor_.NumElements(); ++i) {
339       EXPECT_FLOAT_EQ(-1, to_tensor_.flat<float>()(i));
340     }
341   }
342 
MaybeSetGPUDevice(Device * dst_device)343   void MaybeSetGPUDevice(Device* dst_device) {
344     if (std::get<TEST_PARAM_DEVICE_TYPE>(GetParam()) ==
345         TEST_PARAM_DEVICE_TYPE_GPU) {
346       dst_device->set_tensorflow_accelerator_device_info(
347           &accelerator_device_info_);
348     }
349   }
350 
351   FakeCache wc_;
352   CancellationManager cm_;
353   std::vector<DeviceMgr*> device_mgrs_;
354   std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
355   std::unordered_map<string, std::vector<DeviceAttributes>> dev_by_task_;
356   std::shared_ptr<UnboundedWorkQueue> work_queue_;
357   std::vector<FakeWorker*> workers_;
358   std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
359   mutex mu_;
360   int num_done_ TF_GUARDED_BY(mu_);
361   condition_variable done_;
362   CallOptions opts_;
363   DeviceLocality device_locality_;
364   AllocatorAttributes alloc_attr_;
365   FakeAllocator fake_allocator_;
366   DeviceBase::AcceleratorDeviceInfo accelerator_device_info_;
367   Tensor expected_value_;
368   Tensor large_response_;
369   Tensor to_tensor_;
370 };
371 
// The producer on task 1 provides the buffer before task 0 issues
// RecvFromPeer.  Both sides must complete OK and to_tensor_ must end up equal
// to expected_value_.
TEST_P(CollRMADistTest, ProdFirstOK) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  // Fatal on failure: dst_device is dereferenced below.
  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  MaybeSetGPUDevice(dst_device);
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();
}
410 
// Task 0 issues RecvFromPeer before the producer on task 1 provides the
// buffer.  Both sides must complete OK and to_tensor_ must end up equal to
// expected_value_.
TEST_P(CollRMADistTest, ConsFirstOK) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  // Fatal on failure: dst_device is dereferenced below.
  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  MaybeSetGPUDevice(dst_device);
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();
}
449 
// Task 0 issues RecvFromPeer with no producer present and then aborts rma_.
// The consumer callback must fire with the "Cancelled" message produced by
// the fake worker's cancel handling.
TEST_P(CollRMADistTest, ConsFirstAbort) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Status consumer_status;
  const string kBufKey = "fake_buf_key";
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  // Fatal on failure: dst_device is dereferenced below.
  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  MaybeSetGPUDevice(dst_device);
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  rma_->StartAbort(errors::Internal("Deliberate Failure"));
  consumer_note.WaitForNotification();
  EXPECT_EQ(consumer_status.error_message(), "Cancelled");
}
475 
// The producer provides large_response_, which has twice as many elements as
// to_tensor_.  The receive must fail with a "Tensor Size Mismatch" error, the
// producer must still complete OK, and to_tensor_ must be left untouched.
TEST_P(CollRMADistTest, ResponseTooLarge) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &large_response_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  // Fatal on failure: dst_device is dereferenced below.
  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  MaybeSetGPUDevice(dst_device);
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  consumer_note.WaitForNotification();
  EXPECT_THAT(consumer_status.error_message(),
              ::testing::HasSubstr("Tensor Size Mismatch"));
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensorUnchanged();
}
515 
// First performs a successful consumer-first receive from task 1.  Then task 1
// is restarted (new devices, attributes not re-broadcast) and a second
// receive from it must fail with FailedPrecondition.
TEST_P(CollRMADistTest, WorkerRestart) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string buf_key = "fake_buf_key";
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  // Fatal on failure: dst_device is dereferenced below.
  TF_ASSERT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  MaybeSetGPUDevice(dst_device);
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  wi->buf_rendezvous()->ProvideBuf(
      buf_key, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();

  // Restart task 1 and check that recv from task 1 to task 0 fails.
  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
  Notification post_restart_note;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &post_restart_note](const Status& s) {
        consumer_status = s;
        post_restart_note.Notify();
      });
  post_restart_note.WaitForNotification();
  EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
}
571 
// With device attributes already resolved, a health check of task 1 must
// report OK.
TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) {
  ResolveDeviceAttributes();
  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status& s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  TF_EXPECT_OK(check_health_status);
}
585 
// Without calling ResolveDeviceAttributes() first, a health check of task 1
// must still report OK.
TEST_P(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status& s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  // TF_EXPECT_OK prints the status on failure, unlike EXPECT_TRUE(ok()).
  TF_EXPECT_OK(check_health_status);
}
598 
// Task 1 is restarted after attributes were resolved, so it now has newly
// created devices.  The health check must report FailedPrecondition.
TEST_P(CollRMADistTest, CheckHealthRestarted) {
  ResolveDeviceAttributes();
  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);

  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status& s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
}
614 
// Task 1 is replaced by a worker that fails every RPC with Unavailable.  The
// health check must surface that Unavailable status.
TEST_P(CollRMADistTest, CheckHealthFailedPeer) {
  ResolveDeviceAttributes();
  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
                /*is_failed*/ true);

  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status& s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(errors::IsUnavailable(check_health_status));
}
631 
// Task 1 is restarted with a GPU device instead of the CPU device whose
// attributes were resolved earlier.  The health check must report
// FailedPrecondition.
TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
  ResolveDeviceAttributes();
  RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status& s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
}
646 
647 INSTANTIATE_TEST_SUITE_P(
648     TensorInBufPtrOrExtra, CollRMADistTest,
649     ::testing::Combine(::testing::Values(TEST_PARAM_TENSOR_LOC_AT_BUF_PTR,
650                                          TEST_PARAM_TENSOR_LOC_IN_EXTRA),
651                        ::testing::Values(TEST_PARAM_DEVICE_TYPE_CPU,
652                                          TEST_PARAM_DEVICE_TYPE_GPU)));
653 
654 }  // namespace
655 }  // namespace tensorflow
656