/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"

#include "google/protobuf/any.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

// The only interesting method on CollectiveRemoteAccessDistributed
// that's not on CollectiveRemoteAccessLocal is RecvFromPeer, which
// issues a RecvBufAsync call against a WorkerInterface. That's all
// that's tested here. Note that RecvFromPeer may first make a
// DeviceResolverInterface::GetDeviceLocalityAsync call in preparation
// for the RecvBufAsync.
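//
// As a rough sketch (eliding most arguments; see the tests below for the
// full call shapes), each test pairs a producer with a consumer:
//
//   // Producer side, on the remote FakeWorker:
//   worker->buf_rendezvous()->ProvideBuf(key, ..., &tensor, ..., done_cb, ...);
//   // Consumer side, on the CollectiveRemoteAccessDistributed under test:
//   rma_->RecvFromPeer(peer_device, peer_task, /*peer_is_local=*/false, key,
//                      dst_device, ..., &to_tensor, ..., done_cb);
//
// Whichever side arrives first waits in the BufRendezvous until the other
// side shows up, or until the rendezvous is aborted.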

namespace tensorflow {
namespace {

class FakeAllocator : public Allocator {
 public:
  string Name() override { return "fake"; }
  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void DeallocateRaw(void* ptr) override { port::AlignedFree(ptr); }
};

static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
                                         Allocator* allocator) {
  class FakeDevice : public Device {
   public:
    explicit FakeDevice(const DeviceAttributes& attr, Allocator* allocator)
        : Device(nullptr, attr), allocator_(allocator) {}
    Status Sync() override { return OkStatus(); }
    Allocator* GetAllocator(AllocatorAttributes) override { return allocator_; }

   private:
    Allocator* const allocator_;
  };
  DeviceAttributes attr;
  attr.set_name(name);
  attr.set_device_type(type);
  attr.mutable_locality()->set_numa_node(3);  // a non-default value
  attr.set_incarnation(random::New64());
  return std::make_unique<FakeDevice>(attr, allocator);
}

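// Arbitrary step id shared by every fake worker's BufRendezvous; producer
// and consumer must agree on it for ProvideBuf/ConsumeBuf calls to match up.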
static int64_t kStepId = 123;

class FakeWorker : public TestWorkerInterface {
 public:
  FakeWorker(const string& name, DeviceMgr* dev_mgr,
             DeviceResolverDistributed* dres, bool is_failed,
             bool set_tensor_in_extra)
      : name_(name),
        device_mgr_(dev_mgr),
        device_resolver_(dres),
        buf_rendezvous_(kStepId, dev_mgr),
        is_failed_(is_failed),
        set_tensor_in_extra_(set_tensor_in_extra) {}

  // Direct access to a BufRendezvous that holds whatever the remote
  // worker is supposed to have.
  BufRendezvous* buf_rendezvous() { return &buf_rendezvous_; }

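  // Reports the attributes of this worker's devices, or Unavailable if the
  // worker has been marked failed.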
  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                      GetStatusResponse* response, bool fail_fast,
                      StatusCallback done) override {
    if (is_failed_) {
      done(errors::Unavailable("peer down"));
      return;
    }
    std::vector<DeviceAttributes> dev_attr;
    device_mgr_->ListDeviceAttributes(&dev_attr);
    for (const auto& da : dev_attr) {
      *response->add_device_attributes() = da;
    }
    done(OkStatus());
  }

  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override {
    if (is_failed_) {
      done(errors::Unavailable("peer down"));
      return;
    }
    opts->SetCancelCallback([this]() {
      // Within this test the call is satisfied by a process-local
      // BufRendezvous table. In a real application the BufRendezvous
      // would be on the other side of a network hop, so call
      // BufRendezvous::StartAbort() from a separate thread to be
      // more consistent with that situation and to avoid mutex deadlock.
      SchedClosure([this]() {
        Env::Default()->SleepForMicroseconds(100);
        buf_rendezvous_.StartAbort(errors::Internal("Cancelled"));
      });
    });
    VLOG(2) << "ConsumeBuf key=" << request->buf_rendezvous_key()
            << " src_device=" << request->src_device()
            << " src_incarnation=" << request->src_incarnation();
    buf_rendezvous_.ConsumeBuf(
        request->buf_rendezvous_key(), request->src_device(),
        request->src_incarnation(),
        [this, opts, request, response, done](const Status& status,
                                              BufRendezvous::Hook* h) {
          Status s = status;
          if (s.ok()) {
            opts->ClearCancelCallback();
            int64_t num_bytes = h->prod_value->TotalBytes();

            if (set_tensor_in_extra_) {
              // Since this is not really RDMA into pre-allocated memory,
              // send the bytes in the response.
              RecvBufRespExtra extra;
              extra.add_tensor_content(string(
                  reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
                  num_bytes));
              response->mutable_transport_options()->PackFrom(extra);
            } else {
              if (request->num_bytes() != num_bytes) {
                s = errors::Internal("Tensor Size Mismatch.");
              } else {
                memcpy(reinterpret_cast<void*>(request->buf_ptr()),
                       DMAHelper::base(h->prod_value), num_bytes);
              }
            }
          }
          done(s);
          if (h) BufRendezvous::DoneWithHook(h);
        },
        nullptr /*cancellation_manager*/);
  }

 private:
  string name_;
  DeviceMgr* device_mgr_;
  DeviceResolverDistributed* device_resolver_;
  BufRendezvous buf_rendezvous_;
  bool is_failed_;
  const bool set_tensor_in_extra_;
};

class FakeCache : public TestWorkerCache {
 public:
  // Override the Locality methods to actually pass through to the
  // worker.
  bool GetDeviceLocalityNonBlocking(const string& device,
                                    DeviceLocality* locality) override {
    return false;
  }

  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
                              StatusCallback done) override {
    string task_name;
    string dev_part;
    if (!DeviceNameUtils::SplitDeviceName(device, &task_name, &dev_part)) {
      done(errors::Internal("failed to parse device name"));
      return;
    }
    auto it = workers_.find(task_name);
    if (it == workers_.end()) {
      done(errors::Internal("failed to find worker ", task_name));
      return;
    }
    WorkerInterface* wi = it->second;
    GetStatusRequest req;
    GetStatusResponse resp;
    Status status = wi->GetStatus(&req, &resp);
    if (!status.ok()) {
      done(status);
      return;
    }
    for (const auto& it : resp.device_attributes()) {
      if (it.name() == device) {
        *locality = it.locality();
        done(OkStatus());
        return;
      }
    }
    done(errors::Internal("device not found: ", device));
  }
};

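// Test parameter: the type of destination device passed to RecvFromPeer. The
// GPU case attaches accelerator device info to the device (see
// MaybeSetGPUDevice below) to exercise the non-CPU copy path.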
enum TEST_PARAM_DEVICE_TYPE {
  TEST_PARAM_DEVICE_TYPE_CPU = 0,
  TEST_PARAM_DEVICE_TYPE_GPU,
};

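// Test parameter: where the fake worker places the tensor bytes. Either it
// writes them directly at the request's buf_ptr (simulating RDMA into
// pre-allocated memory) or it packs them into the response's
// transport_options as a RecvBufRespExtra proto.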
enum TEST_PARAM_TENSOR_LOC {
  TEST_PARAM_TENSOR_LOC_AT_BUF_PTR = 0,
  TEST_PARAM_TENSOR_LOC_IN_EXTRA,
};

class CollRMADistTest
    : public ::testing::TestWithParam<
          std::tuple<TEST_PARAM_DEVICE_TYPE, TEST_PARAM_TENSOR_LOC>> {
 protected:
  CollRMADistTest()
      : work_queue_(
            std::make_shared<UnboundedWorkQueue>(Env::Default(), "test")) {}

  ~CollRMADistTest() override {
    for (DeviceMgr* dm : device_mgrs_) {
      delete dm;
    }
    for (auto it : dev_resolvers_) {
      delete it.second;
    }
    for (FakeWorker* w : workers_) {
      delete w;
    }
  }

  void SetUp() override {
    const int num_workers = 2;
    const int num_devices = 1;
    string device_type = "CPU";
    string dev0_worker_name;
    for (int w = 0; w < num_workers; ++w) {
      string name = strings::StrCat("/job:worker/replica:0/task:", w);
      if (w == 0) {
        dev0_worker_name = name;
      }
      DefineWorker(name, device_type, num_devices);
    }
    // All tests simulate requests from worker 0 to worker 1.
    rma_.reset(new CollectiveRemoteAccessDistributed(
        device_mgrs_[0], dev_resolvers_[dev0_worker_name], work_queue_, &wc_,
        kStepId, "/job:worker/replica:0/task:0"));

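    // expected_value_ is what the producer will provide; to_tensor_ is the
    // consumer's destination, filled with a sentinel; large_response_ is
    // deliberately oversized so ResponseTooLarge can trigger a size mismatch.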
    const int kNumElts = 8;
    expected_value_ = Tensor(DT_FLOAT, {kNumElts});
    to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
    large_response_ = Tensor(DT_FLOAT, {2 * kNumElts});
    auto exp_alias = expected_value_.flat<float>();
    auto to_alias = to_tensor_.flat<float>();
    auto large_response_alias = large_response_.flat<float>();
    for (int i = 0; i < kNumElts; ++i) {
      exp_alias(i) = i;
      to_alias(i) = -1;
    }
    for (int i = 0; i < 2 * kNumElts; ++i) {
      large_response_alias(i) = -2;
    }
  }

  // Populates all device resolvers with the device attributes of the cluster.
  // Call this at the beginning of a test unless you want to simulate the
  // state before parameter resolution has happened.
  void ResolveDeviceAttributes() {
    for (auto& dev_resolver_item : dev_resolvers_) {
      DeviceResolverDistributed* dev_resolver = dev_resolver_item.second;
      for (const auto& item : dev_by_task_) {
        TF_CHECK_OK(dev_resolver->UpdateDeviceAttributes(item.second));
      }
    }
  }

  void DefineWorker(const string& worker_name, const string& device_type,
                    int num_devices, bool is_failed = false) {
    std::vector<std::unique_ptr<Device>> devices;
    for (int i = 0; i < num_devices; ++i) {
      devices.push_back(NewDevice(
          device_type,
          strings::StrCat(worker_name, "/device:", device_type, ":", i),
          &fake_allocator_));
    }
    DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
    device_mgrs_.push_back(dev_mgr);
    std::vector<DeviceAttributes>* dv = &dev_by_task_[worker_name];
    dv->clear();
    for (auto d : dev_mgr->ListDevices()) {
      dv->push_back(d->attributes());
    }
    DeviceResolverDistributed* dev_res = new DeviceResolverDistributed(dev_mgr);
    dev_resolvers_[worker_name] = dev_res;
    FakeWorker* fw =
        new FakeWorker(worker_name, dev_mgr, dev_res, is_failed,
                       /*set_tensor_in_extra=*/
                       std::get<TEST_PARAM_TENSOR_LOC>(GetParam()) ==
                           TEST_PARAM_TENSOR_LOC_IN_EXTRA);

    workers_.push_back(fw);
    wc_.AddWorker(worker_name, fw);
  }

  void RestartWorker(const string& worker_name, const string& device_type,
                     int num_devices, bool is_failed = false) {
    auto it = dev_resolvers_.find(worker_name);
    if (it != dev_resolvers_.end()) {
      delete it->second;
      dev_resolvers_.erase(it);
    }
    // After a worker restarts, the other workers still hold the device
    // attributes of the old worker. To mimic the real world, we do not
    // broadcast the new worker's device attributes.
    DefineWorker(worker_name, device_type, num_devices, is_failed);
  }

  void ValidateResultTensor() {
    ASSERT_EQ(expected_value_.NumElements(), to_tensor_.NumElements());
    for (int i = 0; i < to_tensor_.NumElements(); ++i) {
      EXPECT_FLOAT_EQ(expected_value_.flat<float>()(i),
                      to_tensor_.flat<float>()(i));
    }
  }

  void ValidateResultTensorUnchanged() {
    for (int i = 0; i < to_tensor_.NumElements(); ++i) {
      EXPECT_FLOAT_EQ(-1, to_tensor_.flat<float>()(i));
    }
  }

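  // When running the GPU flavor of the test, attach accelerator device info
  // so RecvFromPeer treats the destination as a non-CPU device.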
  void MaybeSetGPUDevice(Device* dst_device) {
    if (std::get<TEST_PARAM_DEVICE_TYPE>(GetParam()) ==
        TEST_PARAM_DEVICE_TYPE_GPU) {
      dst_device->set_tensorflow_accelerator_device_info(
          &accelerator_device_info_);
    }
  }

  FakeCache wc_;
  CancellationManager cm_;
  std::vector<DeviceMgr*> device_mgrs_;
  std::unordered_map<string, DeviceResolverDistributed*> dev_resolvers_;
  std::unordered_map<string, std::vector<DeviceAttributes>> dev_by_task_;
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  std::vector<FakeWorker*> workers_;
  std::unique_ptr<CollectiveRemoteAccessDistributed> rma_;
  mutex mu_;
  int num_done_ TF_GUARDED_BY(mu_);
  condition_variable done_;
  CallOptions opts_;
  DeviceLocality device_locality_;
  AllocatorAttributes alloc_attr_;
  FakeAllocator fake_allocator_;
  DeviceBase::AcceleratorDeviceInfo accelerator_device_info_;
  Tensor expected_value_;
  Tensor large_response_;
  Tensor to_tensor_;
};

TEST_P(CollRMADistTest, ProdFirstOK) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  MaybeSetGPUDevice(dst_device);
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();
}

TEST_P(CollRMADistTest, ConsFirstOK) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  MaybeSetGPUDevice(dst_device);
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();
}

TEST_P(CollRMADistTest, ConsFirstAbort) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Status consumer_status;
  const string kBufKey = "fake_buf_key";
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  MaybeSetGPUDevice(dst_device);
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  rma_->StartAbort(errors::Internal("Deliberate Failure"));
  consumer_note.WaitForNotification();
  EXPECT_EQ(consumer_status.error_message(), "Cancelled");
}

TEST_P(CollRMADistTest, ResponseTooLarge) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string kBufKey = "fake_buf_key";
  wi->buf_rendezvous()->ProvideBuf(
      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &large_response_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  DeviceContext* to_device_ctx = nullptr;
  MaybeSetGPUDevice(dst_device);
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  consumer_note.WaitForNotification();
  EXPECT_THAT(consumer_status.error_message(),
              ::testing::HasSubstr("Tensor Size Mismatch"));
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensorUnchanged();
}

TEST_P(CollRMADistTest, WorkerRestart) {
  ResolveDeviceAttributes();
  Notification consumer_note;
  Notification producer_note;
  Status consumer_status;
  Status producer_status;
  FakeWorker* wi = workers_[1];
  const string buf_key = "fake_buf_key";
  Device* dst_device = nullptr;
  string dev_name = "CPU:0";
  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
  MaybeSetGPUDevice(dst_device);
  DeviceContext* to_device_ctx = nullptr;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &consumer_note](const Status& s) {
        consumer_status = s;
        consumer_note.Notify();
      });
  wi->buf_rendezvous()->ProvideBuf(
      buf_key, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_,
      AllocatorAttributes(),
      [&producer_note, &producer_status](const Status& s) {
        producer_status.Update(s);
        producer_note.Notify();
      },
      nullptr /*cancellation_manager*/);
  consumer_note.WaitForNotification();
  TF_EXPECT_OK(consumer_status);
  producer_note.WaitForNotification();
  TF_EXPECT_OK(producer_status);
  ValidateResultTensor();

  // Restart task 1 and check that recv from task 1 to task 0 fails.
  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
  Notification post_restart_note;
  rma_->RecvFromPeer(
      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
      "/job:worker/replica:0/task:1",                     // peer_task
      false,                                              // peer_is_local
      buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
      device_locality_, 0 /*dev_to_dev_stream_index*/,
      nullptr /*cancellation_manager*/,
      [&consumer_status, &post_restart_note](const Status& s) {
        consumer_status = s;
        post_restart_note.Notify();
      });
  post_restart_note.WaitForNotification();
  EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
}

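// The CheckPeerHealth tests below reflect its observable behavior here:
// it succeeds while the peer is reachable and any locally cached device
// attributes still match the peer's current incarnations; without cached
// attributes it only verifies reachability.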
TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) {
  ResolveDeviceAttributes();
  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  TF_EXPECT_OK(check_health_status);
}

TEST_P(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(check_health_status.ok());
}

TEST_P(CollRMADistTest, CheckHealthRestarted) {
  ResolveDeviceAttributes();
  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);

  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
}

TEST_P(CollRMADistTest, CheckHealthFailedPeer) {
  ResolveDeviceAttributes();
  RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
                /*is_failed*/ true);

  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(errors::IsUnavailable(check_health_status));
}

TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
  ResolveDeviceAttributes();
  RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
  Status check_health_status;
  Notification check_health_done;
  rma_->CheckPeerHealth(
      "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0,
      [&check_health_status, &check_health_done](const Status s) {
        check_health_status = s;
        check_health_done.Notify();
      });
  check_health_done.WaitForNotification();
  EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
}

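// Instantiate the suite over all four (device type, tensor location)
// combinations. The generator order must match the fixture's
// std::tuple<TEST_PARAM_DEVICE_TYPE, TEST_PARAM_TENSOR_LOC> parameter type.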
INSTANTIATE_TEST_SUITE_P(
    TensorInBufPtrOrExtra, CollRMADistTest,
    ::testing::Combine(::testing::Values(TEST_PARAM_DEVICE_TYPE_CPU,
                                         TEST_PARAM_DEVICE_TYPE_GPU),
                       ::testing::Values(TEST_PARAM_TENSOR_LOC_AT_BUF_PTR,
                                         TEST_PARAM_TENSOR_LOC_IN_EXTRA)));

}  // namespace
}  // namespace tensorflow