1 #include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
2
3 #ifdef USE_TENSORPIPE
4
5 #include <limits>
6 #include <tuple>
7 #include <utility>
8
9 #include <fmt/format.h>
10 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
11 #include <tensorpipe/tensorpipe.h>
12 C10_DIAGNOSTIC_POP()
13
14 #include <torch/csrc/distributed/rpc/agent_utils.h>
15 #include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
16 #include <torch/csrc/distributed/rpc/utils.h>
17
18 #include <c10/core/StreamGuard.h>
19 #include <c10/util/irange.h>
20
21 namespace torch::distributed::rpc {
22
23 namespace {
24
// An environment variable along the lines of GLOO_ and NCCL_SOCKET_IFNAME that
// allows the user to specify a device to bind to, instead of binding to the
// address that the hostname resolves to.
const std::string kSocketIfnameEnvVar = "TP_SOCKET_IFNAME";
// Address that TensorPipeAgent::guessAddress() falls back to when looking up
// the interface or hostname address fails.
const std::string kDefaultUvAddress = "127.0.0.1";

// String keys naming agent metrics. kGilAverageWaitTime is the key under which
// the constructor registers a time-series tracker; the remaining keys are
// presumably used where metrics are reported (not visible in this part of the
// file -- confirm against their users).
const std::string kGilAverageWaitTime = "agent.gil_average_wait_time_us";
const std::string kThreadPoolSize = "agent.thread_pool_size";
const std::string kNumIdleThreads = "agent.num_idle_threads";
const std::string kClientActiveCalls = "agent.client_active_calls";
const std::string kServerActiveCalls = "agent.server_active_calls";
const std::string kServerActiveAsyncCalls = "agent.server_active_async_calls";
37
getDevicesForTensors(const std::vector<torch::Tensor> & tensors,const DeviceMap & deviceMap,const std::string & remoteName)38 std::vector<c10::Device> getDevicesForTensors(
39 const std::vector<torch::Tensor>& tensors,
40 const DeviceMap& deviceMap,
41 const std::string& remoteName) {
42 // If the deviceMap is overridden, use that instead.
43 const auto errStr = c10::str(
44 "TensorPipe RPC backend only supports CPU tensors by default, please "
45 "move your tensors to CPU before sending them over RPC, or call "
46 "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
47 "configure device mapping. ",
48 "Request device mapping is not available for destination ",
49 remoteName);
50 std::vector<c10::Device> devices;
51 devices.reserve(tensors.size());
52 bool hasMappedDevice = false;
53 for (const auto& t : tensors) {
54 if (t.device().is_cpu()) {
55 const auto deviceIter = deviceMap.find(c10::kCPU);
56 if (deviceIter == deviceMap.end()) {
57 devices.emplace_back(c10::kCPU);
58 } else {
59 devices.emplace_back(deviceIter->second);
60 hasMappedDevice = true;
61 }
62 } else {
63 const auto deviceIter = deviceMap.find(t.device());
64 TORCH_CHECK(
65 deviceIter != deviceMap.end(),
66 errStr,
67 " for device ",
68 t.device(),
69 " but received a tensor on that device.");
70 devices.push_back(deviceIter->second);
71 hasMappedDevice = true;
72 }
73 }
74 if (!hasMappedDevice) {
75 devices.clear();
76 }
77 return devices;
78 }
79
getStreamsFromPoolForDevices(const std::vector<c10::Device> & devices)80 std::vector<c10::Stream> getStreamsFromPoolForDevices(
81 const std::vector<c10::Device>& devices) {
82 if (devices.empty()) {
83 return {};
84 }
85 c10::impl::VirtualGuardImpl impl(devices[0].type());
86 std::vector<c10::Stream> streams;
87 streams.reserve(devices.size());
88 for (const c10::Device& device : devices) {
89 TORCH_INTERNAL_ASSERT(device.type() == impl.type());
90 streams.push_back(impl.getStreamFromGlobalPool(device));
91 }
92 return streams;
93 }
94
getCurrentStreamsForDevices(const std::vector<c10::Device> & devices)95 std::vector<c10::Stream> getCurrentStreamsForDevices(
96 const std::vector<c10::Device>& devices) {
97 if (devices.empty()) {
98 return {};
99 }
100 c10::impl::VirtualGuardImpl impl(devices[0].type());
101 std::vector<c10::Stream> streams;
102 streams.reserve(devices.size());
103 for (const c10::Device& device : devices) {
104 TORCH_INTERNAL_ASSERT(device.type() == impl.type());
105 streams.push_back(impl.getStream(device));
106 }
107 return streams;
108 }
109
getDevicesOfTensors(const std::vector<torch::Tensor> & tensors)110 std::vector<c10::Device> getDevicesOfTensors(
111 const std::vector<torch::Tensor>& tensors) {
112 std::optional<c10::impl::VirtualGuardImpl> impl;
113 size_t deviceCount = 0;
114 std::vector<bool> indexBitset;
115 for (const torch::Tensor& tensor : tensors) {
116 if (!tensor.is_cpu()) {
117 c10::Device device = tensor.device();
118 if (!impl.has_value()) {
119 impl.emplace(device.type());
120 indexBitset.resize(impl->deviceCount());
121 }
122 TORCH_INTERNAL_ASSERT(device.type() == impl->type());
123 TORCH_INTERNAL_ASSERT(device.has_index());
124 if (!indexBitset[device.index()]) {
125 deviceCount++;
126 indexBitset[device.index()] = true;
127 }
128 }
129 }
130 std::vector<c10::Device> devices;
131 devices.reserve(deviceCount);
132 for (const auto idx : c10::irange(indexBitset.size())) {
133 if (indexBitset[idx]) {
134 devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx));
135 }
136 }
137 return devices;
138 }
139
makeStreamsWaitOnOthers(const std::vector<c10::Stream> & consumers,const std::vector<c10::Stream> & producers)140 void makeStreamsWaitOnOthers(
141 const std::vector<c10::Stream>& consumers,
142 const std::vector<c10::Stream>& producers) {
143 for (const c10::Stream& producer : producers) {
144 const c10::Stream& consumer =
145 getStreamForDevice(consumers, producer.device());
146 c10::Event event(producer.device_type());
147 event.record(producer);
148 event.block(consumer);
149 }
150 }
151
152 } // namespace
153
// Registry of TransportRegistration creators. It is populated by the
// C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ...) statements in this
// file and iterated by TensorPipeAgent::startImpl().
C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeTransportRegistry,
    TransportRegistration);

// Registry of ChannelRegistration creators. It is populated by the
// C10_REGISTER_CREATOR(TensorPipeChannelRegistry, ...) statements in this file
// and iterated by TensorPipeAgent::startImpl().
C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeChannelRegistry,
    ChannelRegistration);
161
guessAddress()162 const std::string& TensorPipeAgent::guessAddress() {
163 static const std::string uvAddress = []() {
164 char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
165 if (ifnameEnv != nullptr) {
166 auto [error, result] =
167 tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
168 if (error) {
169 LOG(WARNING) << "Failed to look up the IP address for interface "
170 << ifnameEnv << " (" << error.what() << "), defaulting to "
171 << kDefaultUvAddress;
172 return kDefaultUvAddress;
173 }
174 return result;
175 }
176 auto [error, result] = tensorpipe::transport::uv::lookupAddrForHostname();
177 if (error) {
178 LOG(WARNING) << "Failed to look up the IP address for the hostname ("
179 << error.what() << "), defaulting to " << kDefaultUvAddress;
180 return kDefaultUvAddress;
181 }
182 return result;
183 }();
184 return uvAddress;
185 }
186
187 namespace {
188
makeUvTransport()189 std::unique_ptr<TransportRegistration> makeUvTransport() {
190 auto context = tensorpipe::transport::uv::create();
191 std::string address = TensorPipeAgent::guessAddress();
192 return std::make_unique<TransportRegistration>(TransportRegistration{
193 std::move(context), kUvTransportPriority, std::move(address)});
194 }
195
196 // The UV transport is implemented using standard TCP connections. It leverages
197 // libuv (https://github.com/libuv/libuv) in order to be cross-platform.
198 C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);
199
200 #if TENSORPIPE_HAS_SHM_TRANSPORT
201
makeShmTransport()202 std::unique_ptr<TransportRegistration> makeShmTransport() {
203 auto context = tensorpipe::transport::shm::create();
204 return std::make_unique<TransportRegistration>(
205 TransportRegistration{std::move(context), kShmTransportPriority, ""});
206 }
207
208 // The SHM implements connections using ringbuffers residing in anonymous shared
209 // memory (plus UNIX domain sockets to bootstrap the connection and exchange
210 // file descriptors). It is Linux-only due to some advanced features (O_TMPFILE,
211 // eventfd, ...).
212 C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport);
213
214 #endif // TENSORPIPE_HAS_SHM_TRANSPORT
215
216 #if TENSORPIPE_HAS_IBV_TRANSPORT
217
makeIbvTransport()218 std::unique_ptr<TransportRegistration> makeIbvTransport() {
219 auto context = tensorpipe::transport::ibv::create();
220 std::string address = TensorPipeAgent::guessAddress();
221 return std::make_unique<TransportRegistration>(TransportRegistration{
222 std::move(context), kIbvTransportPriority, std::move(address)});
223 }
224
225 // The IBV transport sends data across using an InfiniBand queue pair, locally
226 // copying data to and from a staging buffer (registered with libibverbs) and
227 // issuing a RDMA write for transferring data across machines (plus a send for
228 // acknowledging it). It bootstraps using a standard TCP connection to exchange
229 // setup information. It is Linux-only.
230 C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport);
231
232 #endif // TENSORPIPE_HAS_IBV_TRANSPORT
233
makeBasicChannel()234 std::unique_ptr<ChannelRegistration> makeBasicChannel() {
235 auto context = tensorpipe::channel::basic::create();
236 return std::make_unique<ChannelRegistration>(
237 ChannelRegistration{std::move(context), kBasicChannelPriority});
238 }
239
240 // The basic channel is just a straightforward adapter wrapper that allows any
241 // transport to be used as a channel.
242 C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel);
243
244 #if TENSORPIPE_HAS_CMA_CHANNEL
245
makeCmaChannel()246 std::unique_ptr<ChannelRegistration> makeCmaChannel() {
247 auto context = tensorpipe::channel::cma::create();
248 return std::make_unique<ChannelRegistration>(
249 ChannelRegistration{std::move(context), kCmaChannelPriority});
250 }
251
252 // The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
253 // and _writev), which allow one process to access the private memory of another
254 // process (as long as they belong to the same user and other security
255 // constraints are satisfied). It does, more or less, what GDB does when it's
256 // attached to a running process.
257 C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel);
258
259 #endif // TENSORPIPE_HAS_CMA_CHANNEL
260
261 constexpr static int kNumUvThreads = 16;
262
makeMultiplexedUvChannel()263 std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel() {
264 std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
265 std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
266 for (const auto laneIdx C10_UNUSED : c10::irange(kNumUvThreads)) {
267 auto context = tensorpipe::transport::uv::create();
268 std::string address = TensorPipeAgent::guessAddress();
269 contexts.push_back(std::move(context));
270 listeners.push_back(contexts.back()->listen(address));
271 }
272 auto context = tensorpipe::channel::mpt::create(
273 std::move(contexts), std::move(listeners));
274 return std::make_unique<ChannelRegistration>(
275 ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority});
276 }
277
278 // The multiplexed UV channel encapsulates multiple UV transports (each with its
279 // own event loop thread). Each channel will, in turn, contain multiple UV
280 // connections, one for each of those contexts. When sending a tensor, its data
281 // is split in equal chunks and each chunks is sent on a different connection
282 // and thus driven by a different thread. This is needed to reach very high
283 // bandwidths.
284 C10_REGISTER_CREATOR(
285 TensorPipeChannelRegistry,
286 mpt_uv,
287 makeMultiplexedUvChannel);
288
289 } // namespace
290
291 ////////////////////////// MetricsTracker /////////////////////////////////
292
TimeSeriesMetricsTracker(uint64_t currentSum,uint64_t currentCount)293 TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
294 uint64_t currentSum,
295 uint64_t currentCount)
296 : currentSum_(currentSum), currentCount_(currentCount) {}
297
addData(uint64_t dataPoint)298 void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
299 currentSum_ += dataPoint;
300 ++currentCount_;
301 }
302
computeAverage() const303 float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
304 return currentCount_ == 0 ? 0 : currentSum_ / (float)currentCount_;
305 }
306
307 //////////////////////// TensorpipeRpcAgent /////////////////////////////////
308
removeFromTimeoutMap(uint64_t messageId)309 void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId) {
310 // Remove entry from timeoutMap_.
311 {
312 std::unique_lock<std::mutex> lock(timeoutMapMutex_);
313 auto it = messageIdToTimeout_.find(messageId);
314 if (it == messageIdToTimeout_.end()) {
315 // Already removed from the map by pollTimeoutRpcs(), no need to
316 // process further.
317 return;
318 }
319
320 auto& expirationTime = it->second;
321
322 auto& timedOutFuturesVector = timeoutMap_[expirationTime];
323 for (auto it = timedOutFuturesVector.begin();
324 it != timedOutFuturesVector.end();
325 it++) {
326 if (it->messageId == messageId) {
327 it = timedOutFuturesVector.erase(it);
328 break;
329 }
330 }
331
332 if (timedOutFuturesVector.empty()) {
333 timeoutMap_.erase(expirationTime);
334 }
335
336 // Remove from messageId to timeout map as well.
337 messageIdToTimeout_.erase(messageId);
338 }
339 }
340
prepareNames(bool isStaticGroup)341 void TensorPipeAgent::prepareNames(bool isStaticGroup) {
342 std::unordered_map<std::string, worker_id_t> nameToId;
343 if (isStaticGroup) {
344 nameToId = collectNames(
345 rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_);
346 } else {
347 nameToId = collectCurrentNames(
348 rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
349 }
350
351 for (const auto& entry : nameToId) {
352 const auto& workerName = entry.first;
353 const auto& workerId = entry.second;
354 workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId));
355 workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId));
356 }
357 }
358
checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store> & store)359 void TensorPipeAgent::checkAndSetStaticGroup(
360 const c10::intrusive_ptr<::c10d::Store>& store) {
361 std::string isStaticGroupKey("rpcIsStaticGroup");
362
363 std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false";
364 std::vector<uint8_t> isStaticGroupVec(
365 (uint8_t*)isStaticGroupStr.c_str(),
366 (uint8_t*)isStaticGroupStr.c_str() + isStaticGroupStr.length());
367 std::vector<uint8_t> returnedVec;
368 returnedVec = store->compareSet(
369 isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec);
370 std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end());
371 // In both cases, the returned value should be the value of isStaticGroupStr,
372 // otherwise there is a discrepency with initialization among one of the
373 // members
374 TORCH_CHECK(
375 returnedVal == isStaticGroupStr,
376 fmt::format(
377 "RPC group mixes statically and dynamically initialized members which is not supported. ",
378 "Static group property is initialized as {} and is trying to be set as {} ",
379 isStaticGroup_,
380 returnedVal));
381 }
382
TensorPipeAgent(const c10::intrusive_ptr<::c10d::Store> & store,std::string selfName,worker_id_t selfId,std::optional<int> worldSize,TensorPipeRpcBackendOptions opts,std::unordered_map<std::string,DeviceMap> reverseDeviceMaps,std::vector<c10::Device> devices,std::unique_ptr<RequestCallback> cb)383 TensorPipeAgent::TensorPipeAgent(
384 const c10::intrusive_ptr<::c10d::Store>& store,
385 std::string selfName,
386 worker_id_t selfId,
387 std::optional<int> worldSize,
388 TensorPipeRpcBackendOptions opts,
389 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
390 std::vector<c10::Device> devices,
391 std::unique_ptr<RequestCallback> cb)
392 : RpcAgent(
393 WorkerInfo(std::move(selfName), selfId),
394 std::move(cb),
395 std::chrono::milliseconds(
396 (long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
397 isStaticGroup_(worldSize.has_value()),
398 store_(store),
399 opts_(std::move(opts)),
400 reverseDeviceMaps_(std::move(reverseDeviceMaps)),
401 devices_(std::move(devices)),
402 threadPool_(opts_.numWorkerThreads),
403 context_(std::make_shared<tensorpipe::Context>(
404 tensorpipe::ContextOptions().name(workerInfo_.name_))),
405 rankToNameStore_("names", store),
406 nameToAddressStore_("addrs", store),
407 shutdownStore_("shutdown", store) {
408 if (isStaticGroup_) {
409 worldSize_ = worldSize.value();
410 }
411
412 // check the static group attribute against store
413 checkAndSetStaticGroup(store);
414
415 // collect worker names
416 prepareNames(isStaticGroup_);
417
418 // Initialize the time-series metrics tracking map
419 timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
420 }
421
// Logs (at verbose level 1) that the agent is being destroyed and then calls
// shutdown().
TensorPipeAgent::~TensorPipeAgent() {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
  shutdown();
}
426
startImpl()427 void TensorPipeAgent::startImpl() {
428 VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";
429
430 std::vector<std::string> addresses;
431 int lowestPriority = std::numeric_limits<int>::max();
432 std::string lowestPriorityTransport;
433
434 // Register transports
435 for (auto& key : TensorPipeTransportRegistry()->Keys()) {
436 int64_t priority = -1;
437 if (opts_.transports.has_value()) {
438 auto iter =
439 std::find(opts_.transports->begin(), opts_.transports->end(), key);
440 if (iter == opts_.transports->end()) {
441 continue;
442 }
443 // Assign priorities in reverse order of occurrence in the vector, so that
444 // a transport that comes before another receives a higher priority.
445 priority =
446 opts_.transports->size() - 1 - (iter - opts_.transports->begin());
447 }
448 std::unique_ptr<TransportRegistration> reg =
449 TensorPipeTransportRegistry()->Create(key);
450 if (!reg->transport->isViable()) {
451 continue;
452 }
453 if (priority == -1) {
454 priority = reg->priority;
455 }
456 if (priority < lowestPriority) {
457 lowestPriority = priority;
458 lowestPriorityTransport = key;
459 }
460 addresses.push_back(c10::str(key, "://", reg->address));
461 context_->registerTransport(priority, key, reg->transport);
462 }
463
464 // Register channels
465 for (auto& key : TensorPipeChannelRegistry()->Keys()) {
466 int64_t priority = -1;
467 if (opts_.channels.has_value()) {
468 auto iter =
469 std::find(opts_.channels->begin(), opts_.channels->end(), key);
470 if (iter == opts_.channels->end()) {
471 continue;
472 }
473 // Assign priorities in reverse order of occurrence in the vector, so
474 // that a channel that comes before another receives a higher priority.
475 priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin());
476 }
477 std::unique_ptr<ChannelRegistration> reg =
478 TensorPipeChannelRegistry()->Create(key);
479 if (!reg->channel->isViable()) {
480 continue;
481 }
482 if (priority == -1) {
483 priority = reg->priority;
484 }
485 context_->registerChannel(priority, key, reg->channel);
486 }
487
488 listener_ = context_->listen(addresses);
489
490 // Store our own url.
491 const auto address = listener_->url(lowestPriorityTransport);
492 nameToAddressStore_.set(workerInfo_.name_, address);
493
494 VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
495 << address;
496
497 for (const auto& p : workerNameToInfo_) {
498 const auto& name = p.first;
499 auto nodeAddrData = nameToAddressStore_.get(name);
500 auto nodeAddrStr =
501 std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
502 workerNameToURL_.insert({name, nodeAddrStr});
503 }
504
505 // Start the Timeout Thread
506 timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);
507
508 listener_->accept([this](
509 const tensorpipe::Error& error,
510 std::shared_ptr<tensorpipe::Pipe> pipe) {
511 onListenerAccepted(error, pipe);
512 });
513 }
514
onListenerAccepted(const tensorpipe::Error & error,std::shared_ptr<tensorpipe::Pipe> & pipe)515 void TensorPipeAgent::onListenerAccepted(
516 const tensorpipe::Error& error,
517 std::shared_ptr<tensorpipe::Pipe>& pipe) {
518 if (error) {
519 if (error.isOfType<tensorpipe::ListenerClosedError>() &&
520 !rpcAgentRunning_.load()) {
521 // This is expected.
522 } else {
523 LOG(WARNING) << "RPC agent for " << workerInfo_.name_
524 << " encountered error when accepting incoming pipe: "
525 << error.what();
526 }
527 return;
528 }
529
530 // Accept the next connection request
531 listener_->accept([this](
532 const tensorpipe::Error& error,
533 std::shared_ptr<tensorpipe::Pipe> pipe) {
534 onListenerAccepted(error, pipe);
535 });
536
537 VLOG(1) << "RPC agent for " << workerInfo_.name_
538 << " accepted incoming pipe from " << pipe->getRemoteName();
539
540 // Arm for server read
541 respond(pipe);
542 }
543
// Asynchronously reads one RPC message from `pipe` and hands it to `fn`.
//
// The read happens in two steps: first the descriptor is read, then buffers
// are allocated for it and the payload is read into them. `fn` is invoked
// exactly once on each path below:
//  - with the error, a null message and no streams if either step fails;
//  - otherwise with the (non-set) error, the deserialized message and the
//    streams that were passed to tensorpipeAllocate().
void TensorPipeAgent::pipeRead(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    std::function<void(
        const tensorpipe::Error&,
        c10::intrusive_ptr<Message>,
        std::vector<c10::Stream>)> fn) noexcept {
  // `pipe` is captured by value so the callback holds its own reference.
  pipe->readDescriptor([this, fn{std::move(fn)}, pipe](
                           const tensorpipe::Error& error,
                           tensorpipe::Descriptor tpDescriptor) mutable {
    if (error) {
      fn(error, c10::intrusive_ptr<Message>(), {});
      return;
    }

    std::vector<c10::Stream> streams;
    {
      // devices_ is read under the group membership lock guard.
      GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
      streams = getStreamsFromPoolForDevices(devices_);
    }
    auto [tpAllocation, tpBuffers] = tensorpipeAllocate(tpDescriptor, streams);

    // The descriptor, the buffers and the streams are moved into the read
    // callback, which later passes them on to tensorpipeDeserialize() / `fn`.
    pipe->read(
        std::move(tpAllocation),
        [tpDescriptor{std::move(tpDescriptor)},
         tpBuffers{
             std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
         fn{std::move(fn)},
         streams{std::move(streams)}](const tensorpipe::Error& error) mutable {
          if (error) {
            fn(error, c10::intrusive_ptr<Message>(), {});
            return;
          }

          // FIXME This does some unpickling, which could be a bit expensive:
          // perhaps it would be best to perform it inside the worker threads?
          c10::intrusive_ptr<Message> rpcMessage = tensorpipeDeserialize(
              std::move(tpDescriptor), std::move(*tpBuffers));

          fn(error, std::move(rpcMessage), std::move(streams));
        });
  });
}
586
pipeWrite(const std::shared_ptr<tensorpipe::Pipe> & pipe,c10::intrusive_ptr<Message> rpcMessage,std::vector<c10::Device> && devices,std::vector<c10::Stream> streams,std::function<void (const tensorpipe::Error &)> fn)587 void TensorPipeAgent::pipeWrite(
588 const std::shared_ptr<tensorpipe::Pipe>& pipe,
589 c10::intrusive_ptr<Message> rpcMessage,
590 std::vector<c10::Device>&& devices,
591 std::vector<c10::Stream> streams,
592 std::function<void(const tensorpipe::Error&)> fn) noexcept {
593 auto [tpMessage, tpBuffers] =
594 tensorpipeSerialize(std::move(rpcMessage), std::move(devices), streams);
595
596 pipe->write(
597 std::move(tpMessage),
598 [tpBuffers{
599 std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
600 fn{std::move(fn)},
601 streams{std::move(streams)}](const tensorpipe::Error& error) {
602 fn(error);
603 });
604 }
605
sendCompletedResponseMessage(std::shared_ptr<tensorpipe::Pipe> & pipe,JitFuture & futureResponseMessage,uint64_t messageId,std::vector<c10::Stream> streams)606 void TensorPipeAgent::sendCompletedResponseMessage(
607 std::shared_ptr<tensorpipe::Pipe>& pipe,
608 JitFuture& futureResponseMessage,
609 uint64_t messageId,
610 std::vector<c10::Stream> streams) {
611 if (!rpcAgentRunning_.load()) {
612 LOG(WARNING) << "RPC agent for " << workerInfo_.name_
613 << " won't send response to request #" << messageId << " to "
614 << pipe->getRemoteName() << ", as the agent is shutting down";
615 return;
616 }
617
618 VLOG(1) << "RPC agent for " << workerInfo_.name_
619 << " is sending response to request #" << messageId << " to "
620 << pipe->getRemoteName();
621
622 if (!futureResponseMessage.hasError()) {
623 c10::intrusive_ptr<Message> responseMessage =
624 futureResponseMessage.value().toCustomClass<Message>();
625 responseMessage->setId(messageId);
626
627 std::vector<c10::Device> devices;
628 try {
629 devices = getDevicesForRemote(pipe->getRemoteName(), *responseMessage);
630 } catch (const std::exception& e) {
631 responseMessage = createExceptionResponse(e.what(), messageId);
632 }
633
634 for (const auto& tensor : responseMessage->tensors()) {
635 const auto device = tensor.device();
636 if (!device.is_cpu()) {
637 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
638 if (std::find(devices_.begin(), devices_.end(), device) ==
639 devices_.end()) {
640 std::ostringstream oss;
641 std::copy(
642 devices_.begin(),
643 devices_.end(),
644 std::ostream_iterator<c10::Device>(oss, ", "));
645 responseMessage = createExceptionResponse(
646 c10::str(
647 "RPC detected that a user-function output tensor on device ",
648 device,
649 ". This device is not one of the input tensor devices: ",
650 oss.str(),
651 "which is not yet supported. Please file a feature request "
652 "issue in PyTorch GitHub repo."),
653 messageId);
654 break;
655 }
656 }
657 }
658
659 pipeWrite(
660 pipe,
661 std::move(responseMessage),
662 std::move(devices),
663 std::move(streams),
664 [this, pipe, messageId](const tensorpipe::Error& error) {
665 if (error) {
666 LOG(WARNING)
667 << "RPC agent for " << workerInfo_.name_
668 << " encountered error when sending response to request #"
669 << messageId << " to " << pipe->getRemoteName() << ": "
670 << error.what();
671 return;
672 }
673
674 VLOG(1) << "RPC agent for " << workerInfo_.name_
675 << " done sending response to request #" << messageId
676 << " to " << pipe->getRemoteName();
677 });
678 } else {
679 pipeWrite(
680 pipe,
681 createExceptionResponse(
682 futureResponseMessage.tryRetrieveErrorMessage(), messageId),
683 /* devices */ {},
684 std::move(streams),
685 [this, pipe, messageId](const tensorpipe::Error& error) {
686 if (error) {
687 LOG(WARNING)
688 << "RPC agent for " << workerInfo_.name_
689 << " encountered error when sending response to request #"
690 << messageId << " to " << pipe->getRemoteName() << ": "
691 << error.what();
692 return;
693 }
694
695 VLOG(1) << "RPC agent for " << workerInfo_.name_
696 << " done sending response to request #" << messageId
697 << " to " << pipe->getRemoteName();
698 });
699 }
700 }
701
// Arms `pipe` for one incoming request and, once it arrives, re-arms the pipe
// for the next one and processes the request.
//
// Processing is deferred to threadPool_: the request callback `cb_` is invoked
// with the request and its streams, and when the resulting future completes
// the response is sent back through sendCompletedResponseMessage(). If the
// read fails, nothing is re-armed; a warning is logged unless shuttingDown_ is
// set.
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          c10::intrusive_ptr<Message> requestMessage,
          std::vector<c10::Stream> streams) mutable {
        if (error) {
          if (shuttingDown_) {
            // This is expected.
          } else {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when reading incoming request from "
                << pipe->getRemoteName() << ": " << error.what();
          }
          return;
        }

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage->id();
        // Paired with the decreaseCallCount(serverActiveCalls_) in the future
        // callback below.
        increaseCallCount(serverActiveCalls_);

        VLOG(1) << "RPC agent for " << workerInfo_.name_
                << " received request #" << messageId << " from "
                << pipe->getRemoteName();

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         streams{std::move(streams)}]() mutable {
          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " is running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
            // Instead of creating a MultiStreamGuard here, the ctx is passed
            // to the callback and the MultiStreamGuard is created there,
            // because subsequent processing can switch threads due to 1)
            // waiting for RRef arguments to become ready 2) async_execution.
            // Besides, the `ctx` also needs to be propagated to
            // `process***Call` methods to synchronize CUDA streams there
            // to make sure that we fetch the correct value from `to_here()`
            // call.
            futureResponseMessage =
                cb_->operator()(*requestMessage, std::move(streams));
          } catch (const std::exception& /* unused */) {
            // Turn a synchronous failure of the callback into a future that
            // already holds the error, so both cases share the path below.
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          // Paired with the decreaseCallCount(serverActiveAsyncCalls_) in the
          // callback below.
          increaseCallCount(serverActiveAsyncCalls_);
          futureResponseMessage->addCallback(
              [this, pipe, messageId](
                  JitFuture& futureResponseMessage) mutable {
                decreaseCallCount(serverActiveCalls_);
                decreaseCallCount(serverActiveAsyncCalls_);
                auto streams = getCurrentStreamsForDevices(
                    futureResponseMessage.devices());
                sendCompletedResponseMessage(
                    pipe, futureResponseMessage, messageId, std::move(streams));
              });

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";
        });
      });
}
777
send(const WorkerInfo & toWorkerInfo,c10::intrusive_ptr<Message> requestMessage,const float rpcTimeoutSeconds,const DeviceMap & deviceMap)778 c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
779 const WorkerInfo& toWorkerInfo,
780 c10::intrusive_ptr<Message> requestMessage,
781 const float rpcTimeoutSeconds,
782 const DeviceMap& deviceMap) {
783 TORCH_CHECK(
784 requestMessage->isRequest(),
785 "TensorPipeAgent::send(..) is only for sending requests.");
786
787 if (!rpcAgentRunning_.load()) {
788 auto err = c10::str(
789 "Node ",
790 RpcAgent::getWorkerInfo().id_,
791 "tried to send() a message of type ",
792 requestMessage->type(),
793 " but RPC is no longer running on this node.");
794 TORCH_CHECK(false, err);
795 }
796
797 const auto& url = findWorkerURL(toWorkerInfo);
798
799 decltype(connectedPipes_)::iterator it;
800 {
801 std::unique_lock<std::mutex> lock(connectedPipesMutex_);
802
803 // See if we already have a connection to this address or not
804 it = connectedPipes_.find(toWorkerInfo.id_);
805 if (it == connectedPipes_.end()) {
806 // An instance of ClientPipe cannot be copied or moved as it contains a
807 // mutex, and to force in-place construction in GCC 5 we need piecewise
808 // construction in order to work around an issue.
809 it = connectedPipes_
810 .emplace(
811 std::piecewise_construct,
812 std::forward_as_tuple(toWorkerInfo.id_),
813 std::forward_as_tuple(context_->connect(
814 url,
815 tensorpipe::PipeOptions().remoteName(
816 toWorkerInfo.name_))))
817 .first;
818 }
819 }
820 ClientPipe& clientPipe = it->second;
821
822 std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture>
823 futureResponseMessage;
824 {
825 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
826 futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
827 }
828 uint64_t messageId = nextMessageID_++;
829 requestMessage->setId(messageId);
830
831 {
832 std::unique_lock<std::mutex> lock(clientPipe.mutex_);
833 clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage;
834 }
835
836 // Get devices for tensors in the request message. This can throw if device
837 // maps are not configured properly for this request.
838 std::vector<c10::Device> devices;
839 if (deviceMap.empty()) {
840 devices =
841 getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage);
842 } else {
843 // If deviceMap is specified, use that instead.
844 devices = getDevicesForTensors(
845 requestMessage->tensors(),
846 deviceMap,
847 clientPipe.pipe_->getRemoteName());
848 }
849
850 futureResponseMessage->jitFuture->addCallback(
851 [this](JitFuture& /* unused */) {
852 TORCH_INTERNAL_ASSERT(
853 this->threadPool_.inThreadPool(),
854 "Future marked complete from outside the thread pool");
855 });
856
857 increaseCallCount(clientActiveCalls_);
858 // Use the default RPC timeout if no timeout is specified for this send call
859 auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
860 ? getRpcTimeout()
861 : std::chrono::milliseconds(
862 static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
863
864 // We only add to the timeoutMap_ if the timeout is not 0. Per our
865 // documentation, a user-provided timeout of 0 indicates the RPC should never
866 // expire (infinite timeout), so there is no need to track it in the
867 // timeoutMap_.
868 steady_clock_time_point expirationTime;
869 if (timeout.count() != 0) {
870 // Compute the expiration time for this message based on the timeout
871 expirationTime = computeRpcMessageExpiryTime(timeout);
872
873 // Add the Future to the right vector in the timeoutMap_
874 {
875 std::unique_lock<std::mutex> lock(timeoutMapMutex_);
876 auto& timeoutFuturesVector = timeoutMap_[expirationTime];
877 messageIdToTimeout_.emplace(messageId, expirationTime);
878 timeoutFuturesVector.emplace_back(
879 messageId, futureResponseMessage, timeout);
880 }
881 timeoutThreadCV_.notify_one();
882 }
883
884 VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
885 << messageId << " to " << clientPipe.pipe_->getRemoteName();
886
887 std::vector<c10::Stream> streams;
888 {
889 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
890 streams = getStreamsFromPoolForDevices(devices_);
891 }
892 makeStreamsWaitOnOthers(
893 streams,
894 getCurrentStreamsForDevices(
895 getDevicesOfTensors(requestMessage->tensors())));
896 pipeWrite(
897 clientPipe.pipe_,
898 std::move(requestMessage),
899 std::move(devices),
900 std::move(streams),
901 [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
902 if (error) {
903 if (error.isOfType<tensorpipe::PipeClosedError>() &&
904 !rpcAgentRunning_.load()) {
905 // This is expected.
906 } else {
907 LOG(WARNING) << "RPC agent for " << workerInfo_.name_
908 << " encountered error when sending outgoing request #"
909 << messageId << " to "
910 << clientPipe.pipe_->getRemoteName() << ": "
911 << error.what();
912 }
913 handleClientError(clientPipe, error);
914 return;
915 }
916
917 VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #"
918 << messageId << " to " << clientPipe.pipe_->getRemoteName();
919
920 pipeRead(
921 clientPipe.pipe_,
922 [this, &clientPipe](
923 const tensorpipe::Error& error,
924 c10::intrusive_ptr<Message> responseMessage,
925 std::vector<c10::Stream> streams) {
926 if (error) {
927 if (error.isOfType<tensorpipe::PipeClosedError>() &&
928 !rpcAgentRunning_.load()) {
929 // This is expected.
930 } else {
931 LOG(WARNING)
932 << "RPC agent for " << workerInfo_.name_
933 << " encountered error when reading incoming response from "
934 << clientPipe.pipe_->getRemoteName() << ": "
935 << error.what();
936 }
937 handleClientError(clientPipe, error);
938 return;
939 }
940
941 // Identify future response message by message ID
942 uint64_t messageId = responseMessage->id();
943
944 VLOG(1) << "RPC agent for " << workerInfo_.name_
945 << " received response #" << messageId << " from "
946 << clientPipe.pipe_->getRemoteName();
947
948 std::shared_ptr<AtomicJitFuture> futureResponseMessage;
949 {
950 std::lock_guard<std::mutex> lock(clientPipe.mutex_);
951 // A read error will lead all following callbacks to be
952 // invoked with error, and shouldn't reach here.
953 TORCH_INTERNAL_ASSERT(
954 !clientPipe.inError_, "Shouldn't be in error state");
955 auto it = clientPipe.pendingResponseMessage_.find(messageId);
956 TORCH_INTERNAL_ASSERT(
957 it != clientPipe.pendingResponseMessage_.end(),
958 "message ID ",
959 messageId,
960 " is not recognized");
961 futureResponseMessage = std::move(it->second);
962 clientPipe.pendingResponseMessage_.erase(it);
963 }
964
965 // Remove entry from timeoutMap_.
966 removeFromTimeoutMap(messageId);
967
968 if (responseMessage->type() == MessageType::EXCEPTION) {
969 markFutureWithError(
970 std::move(futureResponseMessage),
971 std::string(
972 responseMessage->payload().begin(),
973 responseMessage->payload().end()));
974 } else {
975 markFutureAsComplete(
976 std::move(futureResponseMessage),
977 std::move(responseMessage),
978 std::move(streams));
979 }
980 });
981 });
982
983 return futureResponseMessage->jitFuture;
984 }
985
handleClientError(ClientPipe & clientPipe,const tensorpipe::Error & error)986 void TensorPipeAgent::handleClientError(
987 ClientPipe& clientPipe,
988 const tensorpipe::Error& error) {
989 // When an error occurs on a pipe all pending operations will be aborted and
990 // all callbacks invoked with error, hence we immediately flush all future
991 // messages belonging to this pipe.
992 decltype(clientPipe.pendingResponseMessage_) pendingMsgs;
993 {
994 std::lock_guard<std::mutex> lock(clientPipe.mutex_);
995 std::swap(clientPipe.pendingResponseMessage_, pendingMsgs);
996 clientPipe.inError_ = true;
997 }
998 std::string errorMsg = error.what();
999 for (auto& p : pendingMsgs) {
1000 markFutureWithError(std::move(p.second), errorMsg);
1001
1002 // Remove entry from timeoutMap_.
1003 removeFromTimeoutMap(p.first);
1004 }
1005 }
1006
pollTimeoutRpcs()1007 void TensorPipeAgent::pollTimeoutRpcs() {
1008 while (rpcAgentRunning_.load()) {
1009 std::unique_lock<std::mutex> lock(timeoutMapMutex_);
1010
1011 // We sleep until the earliest expiring RPC in the timeoutMap_. We must
1012 // also ensure that we sleep while the map is empty, and we exit sleeping
1013 // if the RPC Agent has been shutdown.
1014 for (;;) {
1015 if (!rpcAgentRunning_.load()) {
1016 return;
1017 }
1018
1019 if (!timeoutMap_.empty()) {
1020 steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
1021 if (std::chrono::steady_clock::now() >= earliestTimeout) {
1022 break;
1023 }
1024 timeoutThreadCV_.wait_until(lock, earliestTimeout);
1025 } else {
1026 timeoutThreadCV_.wait(lock);
1027 }
1028 }
1029
1030 // Move all these futures to a separate vector so we can process them
1031 // outside the lock.
1032 std::vector<TimeoutMessageMetadata> timedOutFutures =
1033 std::move(timeoutMap_.begin()->second);
1034
1035 // We can safely remove this key from the timeoutMap_ since all these
1036 // futures will be processed.
1037 timeoutMap_.erase(timeoutMap_.begin());
1038
1039 for (auto& timeoutMetadata : timedOutFutures) {
1040 // Remove from messageIdToTimeout map.
1041 messageIdToTimeout_.erase(timeoutMetadata.messageId);
1042 }
1043 lock.unlock();
1044
1045 // Set an error on futures added to the timedOutFutures vector. We do this
1046 // outside the lock to prevent potential lock-order-inversions by callbacks
1047 // triggered by the setError call.
1048 for (auto& timeoutMetadata : timedOutFutures) {
1049 std::string errorMsg =
1050 fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
1051 auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
1052 markFutureWithError(
1053 std::move(timeoutMetadata.responseFuture), std::move(err));
1054 }
1055 }
1056 }
1057
leaveGroup()1058 void TensorPipeAgent::leaveGroup() {
1059 std::unique_lock<std::mutex> lock(callCountMutex_);
1060 // local worker ActiveCallCount is 0 at this point and we will shutdown
1061 // (any future calls will be dropped)
1062 callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });
1063
1064 // Remove this agent's WorkerInfo from store
1065 removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
1066
1067 // Set internal variable to be used during destructor
1068 shuttingDown_ = true;
1069 }
1070
1071 // TODO: Remove join()
join(bool shutdown,float)1072 void TensorPipeAgent::join(bool shutdown, float /* unused */) {
1073 VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
1074 if (!isStaticGroup_) {
1075 leaveGroup();
1076 return;
1077 }
1078
1079 // This method behaves like a barrier, as it can only return once all workers
1080 // have no more requests pending, including "nested" requests (triggered from
1081 // within the remote code of another call) and "follow-up" requests (triggered
1082 // from the callback of a future).
1083 while (true) {
1084 {
1085 std::unique_lock<std::mutex> lock(callCountMutex_);
1086 // It is enough to wait for there to be no more active client calls, since
1087 // each server call corresponds to a client call for some other worker.
1088 callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });
1089
1090 // We'd like to immediately proceed with the allreduce, but it's a call
1091 // that may block for some time, as it waits for other workers to also
1092 // complete all their active client calls. While we call allreduce we must
1093 // hold the mutex, or else the count we send to other workers may get
1094 // stale (e.g., if some nested call happens in the meantime). But we can't
1095 // hold the lock for an indeterminately long time, as that would block
1096 // other operations (e.g., send). Thus we must release the lock and only
1097 // re-acquire it when all workers are ready to proceed with the allreduce.
1098 // We perform this synchronization using a barrier.
1099 }
1100 VLOG(1) << "RPC agent for " << workerInfo_.name_
1101 << " completed all client calls and is entering a barrier";
1102 syncCallCount(shutdownStore_, worldSize_);
1103 {
1104 std::unique_lock<std::mutex> lock(callCountMutex_);
1105 // At this point, the count may have become non-zero again. We can't wait
1106 // for those calls to complete as other workers are waiting for us in the
1107 // allreduce and we would block them. Thus we send our count even if it is
1108 // non-zero and if anyone (be it us or another worker) has a non-zero
1109 // count we'll just do another round.
1110 VLOG(1) << "RPC agent for " << workerInfo_.name_
1111 << " exited the barrier and found " << clientActiveCalls_
1112 << " active client calls";
1113 int totalClientActiveCalls =
1114 syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
1115 VLOG(1) << "RPC agent for " << workerInfo_.name_
1116 << " completed sync call counts and got a total of "
1117 << totalClientActiveCalls
1118 << " active client calls across all workers";
1119 if (totalClientActiveCalls == 0) {
1120 if (shutdown) {
1121 shuttingDown_ = true;
1122 syncCallCount(shutdownStore_, worldSize_);
1123 }
1124 break;
1125 }
1126 }
1127 }
1128 VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
1129 }
1130
shutdownImpl()1131 void TensorPipeAgent::shutdownImpl() {
1132 // FIXME Isn't it too verbose for a library to print logs in normal operation?
1133 LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";
1134
1135 // Join the Timeout Thread
1136 timeoutThreadCV_.notify_one();
1137 if (timeoutThread_.joinable()) {
1138 timeoutThread_.join();
1139 }
1140 VLOG(1) << "RPC agent for " << workerInfo_.name_
1141 << " done waiting for timeout thread to join";
1142
1143 // This will close all the pipes and listeners, invoke all callbacks with
1144 // errors, turn down the I/O event loops and wait for everything to terminate.
1145 context_->join();
1146 VLOG(1) << "RPC agent for " << workerInfo_.name_
1147 << " done waiting for TensorPipe context to join";
1148
1149 // NOTE: We need to call waitWorkComplete in the end after we have shutdown
1150 // all listeners for Tensorpipe. This is to drain any already accepted work
1151 // in the ThreadPool. If this is done before we shutdown the listeners,
1152 // additional work could be added after this call and before we shutdown
1153 // listeners. This work would continue executing in the threadpool and might
1154 // cause issues during shutdown of the system.
1155 threadPool_.waitWorkComplete();
1156 VLOG(1) << "RPC agent for " << workerInfo_.name_
1157 << " done waiting for thread pool to complete work";
1158 }
1159
getWorkerInfo(const std::string & workerName) const1160 const WorkerInfo& TensorPipeAgent::getWorkerInfo(
1161 const std::string& workerName) const {
1162 std::unordered_map<std::string, WorkerInfo>::const_iterator it;
1163 {
1164 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1165 it = workerNameToInfo_.find(workerName);
1166 }
1167 TORCH_CHECK(
1168 it != workerNameToInfo_.end(),
1169 fmt::format(
1170 "name:{},rank:{} could not find destination name {}",
1171 workerInfo_.name_,
1172 workerInfo_.id_,
1173 workerName));
1174 return it->second;
1175 }
1176
getWorkerInfo(worker_id_t workerId) const1177 const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const {
1178 std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it;
1179 {
1180 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1181 it = workerIdToInfo_.find(workerId);
1182 }
1183 TORCH_CHECK(
1184 it != workerIdToInfo_.end(),
1185 fmt::format(
1186 "name:{},rank:{} could not find destination id {}",
1187 workerInfo_.name_,
1188 workerInfo_.id_,
1189 workerId));
1190 return it->second;
1191 }
1192
getWorkerInfos() const1193 std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {
1194 std::vector<WorkerInfo> workerInfos;
1195 workerInfos.reserve(workerNameToInfo_.size());
1196 for (auto& item : workerNameToInfo_) {
1197 workerInfos.emplace_back(item.second);
1198 }
1199 return workerInfos;
1200 }
1201
findWorkerURL(const WorkerInfo & worker) const1202 const std::string& TensorPipeAgent::findWorkerURL(
1203 const WorkerInfo& worker) const {
1204 std::unordered_map<std::string, std::string>::const_iterator it;
1205 {
1206 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1207 it = workerNameToURL_.find(worker.name_);
1208 }
1209 TORCH_CHECK(
1210 it != workerNameToURL_.end(),
1211 fmt::format(
1212 "name:{},rank:{} could not find destination url for name {}",
1213 workerInfo_.name_,
1214 workerInfo_.id_,
1215 worker.name_));
1216 return it->second;
1217 }
1218
updateGroupMembership(const WorkerInfo & workerInfo,const std::vector<c10::Device> & devices,const std::unordered_map<std::string,DeviceMap> & reverseDeviceMaps,bool isJoin)1219 void TensorPipeAgent::updateGroupMembership(
1220 const WorkerInfo& workerInfo,
1221 const std::vector<c10::Device>& devices,
1222 const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
1223 bool isJoin) {
1224 std::string name = workerInfo.name_;
1225 worker_id_t id = workerInfo.id_;
1226 // Rank with workerInfo is joining the group, update internal mappings
1227 if (isJoin) {
1228 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1229 workerIdToInfo_.emplace(id, workerInfo);
1230 workerNameToInfo_.emplace(name, workerInfo);
1231
1232 // TODO: we should get nodeAddrStr in the joining process, then pass in as
1233 // an argument rather than getting from store each time
1234 auto nodeAddrData = nameToAddressStore_.get(name);
1235 auto nodeAddrStr =
1236 std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
1237 workerNameToURL_.insert({name, nodeAddrStr});
1238
1239 for (const auto& it : reverseDeviceMaps) {
1240 if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
1241 reverseDeviceMaps_[it.first] = it.second;
1242 }
1243 }
1244 // TODO: clean up mutex for devices_ usage
1245 // Add devices that have not been added yet
1246 for (const auto& it : devices) {
1247 if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
1248 devices_.push_back(it);
1249 }
1250 }
1251 } else {
1252 workerIdToInfo_.erase(id);
1253 workerNameToInfo_.erase(name);
1254 workerNameToURL_.erase(name);
1255
1256 // remove reverse device maps that are no longer used
1257 for (auto it = reverseDeviceMaps_.begin();
1258 it != reverseDeviceMaps_.end();) {
1259 if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
1260 it = reverseDeviceMaps_.erase(it);
1261 } else {
1262 it++;
1263 }
1264 }
1265
1266 // remove devices that are no longer used
1267 for (auto it = devices_.begin(); it != devices_.end();) {
1268 if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
1269 it = devices_.erase(it);
1270 } else {
1271 it++;
1272 }
1273 }
1274 }
1275 }
getMetrics()1276 std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() {
1277 std::unordered_map<std::string, std::string> metrics;
1278 metrics[kThreadPoolSize] = std::to_string(threadPool_.size());
1279 metrics[kNumIdleThreads] = std::to_string(threadPool_.numAvailable());
1280 {
1281 std::unique_lock<std::mutex> lock(callCountMutex_);
1282 metrics[kClientActiveCalls] = std::to_string(clientActiveCalls_);
1283 metrics[kServerActiveCalls] = std::to_string(serverActiveCalls_);
1284 metrics[kServerActiveAsyncCalls] = std::to_string(serverActiveAsyncCalls_);
1285 }
1286 if (isGILProfilingEnabled()) {
1287 {
1288 std::unique_lock<std::mutex> lock(metricsMutex_);
1289 // Include the averages for each time series metric. This is just the GIL
1290 // Wait Time for now.
1291 auto averageGilWaitTime =
1292 timeSeriesMetrics_[kGilAverageWaitTime].computeAverage();
1293 lock.unlock();
1294 metrics[kGilAverageWaitTime] = std::to_string(averageGilWaitTime);
1295 }
1296 }
1297
1298 return metrics;
1299 }
1300
addGilWaitTime(const std::chrono::microseconds gilWaitTime)1301 void TensorPipeAgent::addGilWaitTime(
1302 const std::chrono::microseconds gilWaitTime) {
1303 std::lock_guard<std::mutex> lock(metricsMutex_);
1304 timeSeriesMetrics_[kGilAverageWaitTime].addData(gilWaitTime.count());
1305 }
1306
getNetworkData()1307 TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData() {
1308 std::lock_guard<std::mutex> lock(networkDataMutex_);
1309 return networkData_;
1310 }
1311
getNetworkSourceInfo()1312 NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo() {
1313 NetworkSourceInfo info = {
1314 RpcAgent::getWorkerInfo().id_,
1315 nameToAddressStore_.get(RpcAgent::getWorkerInfo().name_)};
1316
1317 return info;
1318 }
1319
trackNetworkData(uint64_t requestSize,uint64_t responseSize,const std::string & destWorkerName)1320 void TensorPipeAgent::trackNetworkData(
1321 uint64_t requestSize,
1322 uint64_t responseSize,
1323 const std::string& destWorkerName) {
1324 std::lock_guard<std::mutex> lock(networkDataMutex_);
1325 networkData_[destWorkerName].numCalls++;
1326 networkData_[destWorkerName].totalSentBytes += requestSize;
1327 networkData_[destWorkerName].totalRecvBytes += responseSize;
1328 }
1329
trackNetworkError(uint64_t requestSize,const std::string & destWorkerName)1330 void TensorPipeAgent::trackNetworkError(
1331 uint64_t requestSize,
1332 const std::string& destWorkerName) {
1333 std::lock_guard<std::mutex> lock(networkDataMutex_);
1334 networkData_[destWorkerName].numCalls++;
1335 networkData_[destWorkerName].totalSentBytes += requestSize;
1336 networkData_[destWorkerName].totalErrors++;
1337 }
1338
increaseCallCount(int32_t & count)1339 void TensorPipeAgent::increaseCallCount(int32_t& count) {
1340 {
1341 std::unique_lock<std::mutex> lock(callCountMutex_);
1342 ++count;
1343 }
1344 callCountCV_.notify_all();
1345 }
1346
decreaseCallCount(int32_t & count)1347 void TensorPipeAgent::decreaseCallCount(int32_t& count) {
1348 {
1349 std::unique_lock<std::mutex> lock(callCountMutex_);
1350 --count;
1351 }
1352 callCountCV_.notify_all();
1353 }
1354
markFutureAsComplete(std::shared_ptr<AtomicJitFuture> atomicFuture,c10::intrusive_ptr<Message> message,std::vector<c10::Stream> streams)1355 void TensorPipeAgent::markFutureAsComplete(
1356 std::shared_ptr<AtomicJitFuture> atomicFuture,
1357 c10::intrusive_ptr<Message> message,
1358 std::vector<c10::Stream> streams) {
1359 if (!atomicFuture->isComplete.test_and_set()) {
1360 // Completing the future will run its callbacks, which could execute
1361 // arbitrary user code. To prevent blocking or stalling the TensorPipe event
1362 // loops, we defer this to a worker thread.
1363 threadPool_.run([this,
1364 atomicFuture{std::move(atomicFuture)},
1365 message{std::move(message)},
1366 streams{std::move(streams)}]() mutable {
1367 c10::MultiStreamGuard guard(streams);
1368 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
1369 message->getStorages();
1370 atomicFuture->jitFuture->markCompleted(
1371 std::move(message), std::move(storages));
1372 // The future's callbacks may schedule further RPCs, increasing the count.
1373 // Thus we must decrease it after completing the future, otherwise it may
1374 // briefly dip to zero and trick join into thinking all work is done.
1375 decreaseCallCount(clientActiveCalls_);
1376 });
1377 }
1378 }
1379
markFutureWithError(std::shared_ptr<AtomicJitFuture> atomicFuture,std::string errorMsg)1380 void TensorPipeAgent::markFutureWithError(
1381 std::shared_ptr<AtomicJitFuture> atomicFuture,
1382 std::string errorMsg) {
1383 if (!atomicFuture->isComplete.test_and_set()) {
1384 // Completing the future will run its callbacks, which could execute
1385 // arbitrary user code. To prevent blocking or stalling the TensorPipe event
1386 // loops, we defer this to a worker thread.
1387 threadPool_.run([this,
1388 atomicFuture{std::move(atomicFuture)},
1389 errorMsg{std::move(errorMsg)}]() mutable {
1390 atomicFuture->jitFuture->setError(
1391 std::make_exception_ptr(std::runtime_error(errorMsg)));
1392 // The future's callbacks may schedule further RPCs, increasing the count.
1393 // Thus we must decrease it after completing the future, otherwise it may
1394 // briefly dip to zero and trick join into thinking all work is done.
1395 decreaseCallCount(clientActiveCalls_);
1396 });
1397 }
1398 }
1399
getDevicesForRemote(const std::string & remoteName,const Message & message) const1400 std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
1401 const std::string& remoteName,
1402 const Message& message) const {
1403 std::unordered_map<std::string, DeviceMap> deviceMaps;
1404 {
1405 GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1406 deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
1407 }
1408
1409 const auto errStr = c10::str(
1410 "TensorPipe RPC backend only supports CPU tensors by default, please "
1411 "move your tensors to CPU before sending them over RPC, or call "
1412 "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
1413 "configure device mapping. ",
1414 message.isRequest() ? "Request" : "Response",
1415 " device mapping is not available for destination ",
1416 remoteName);
1417
1418 const auto& iter = deviceMaps.find(remoteName);
1419 if (iter == deviceMaps.end()) {
1420 for (const auto& t : message.tensors()) {
1421 TORCH_CHECK(
1422 t.device().is_cpu(),
1423 errStr,
1424 ", but found tensor on device: ",
1425 t.device());
1426 }
1427 return {};
1428 } else {
1429 return getDevicesForTensors(message.tensors(), iter->second, errStr);
1430 }
1431 }
1432
getDeviceMap(const WorkerInfo & dst) const1433 DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
1434 auto it = opts_.deviceMaps.find(dst.name_);
1435 if (it == opts_.deviceMaps.end()) {
1436 return {};
1437 }
1438 return it->second;
1439 }
1440
// Returns the agent's store handle (a copy of the store_ intrusive pointer).
const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const {
  return store_;
}
1444
// Returns a copy of the backend options (opts_) held by this agent.
TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
  return opts_;
}
1448
// Returns a reference to the agent's device list (devices_).
// NOTE(review): the group membership guard is released when this function
// returns, while the returned reference still aliases devices_, which
// updateGroupMembership mutates. Presumably callers must not hold on to the
// reference across membership changes — confirm.
const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  return devices_;
}
1453
timeoutMapSize()1454 size_t TensorPipeAgent::timeoutMapSize() {
1455 std::unique_lock<std::mutex> lock(timeoutMapMutex_);
1456 return timeoutMap_.size();
1457 }
1458
numPendingResponses()1459 size_t TensorPipeAgent::numPendingResponses() {
1460 std::unique_lock<std::mutex> lock(callCountMutex_);
1461 return clientActiveCalls_;
1462 }
1463
messageIdToTimeoutMapSize()1464 size_t TensorPipeAgent::messageIdToTimeoutMapSize() {
1465 std::unique_lock<std::mutex> lock(timeoutMapMutex_);
1466 return messageIdToTimeout_.size();
1467 }
1468
1469 } // namespace torch::distributed::rpc
1470
1471 #endif // USE_TENSORPIPE
1472