xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/tensorpipe_agent.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
2 
3 #ifdef USE_TENSORPIPE
4 
5 #include <limits>
6 #include <tuple>
7 #include <utility>
8 
9 #include <fmt/format.h>
10 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
11 #include <tensorpipe/tensorpipe.h>
12 C10_DIAGNOSTIC_POP()
13 
14 #include <torch/csrc/distributed/rpc/agent_utils.h>
15 #include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
16 #include <torch/csrc/distributed/rpc/utils.h>
17 
18 #include <c10/core/StreamGuard.h>
19 #include <c10/util/irange.h>
20 
21 namespace torch::distributed::rpc {
22 
23 namespace {
24 
// Name of the environment variable (in the spirit of GLOO_SOCKET_IFNAME and
// NCCL_SOCKET_IFNAME) through which users can choose the network interface to
// bind to, instead of the address that the hostname resolves to.
const std::string kSocketIfnameEnvVar{"TP_SOCKET_IFNAME"};
// Address used whenever the interface/hostname lookup fails.
const std::string kDefaultUvAddress{"127.0.0.1"};

// Keys under which agent metrics are reported.
const std::string kGilAverageWaitTime{"agent.gil_average_wait_time_us"};
const std::string kThreadPoolSize{"agent.thread_pool_size"};
const std::string kNumIdleThreads{"agent.num_idle_threads"};
const std::string kClientActiveCalls{"agent.client_active_calls"};
const std::string kServerActiveCalls{"agent.server_active_calls"};
const std::string kServerActiveAsyncCalls{"agent.server_active_async_calls"};
37 
getDevicesForTensors(const std::vector<torch::Tensor> & tensors,const DeviceMap & deviceMap,const std::string & remoteName)38 std::vector<c10::Device> getDevicesForTensors(
39     const std::vector<torch::Tensor>& tensors,
40     const DeviceMap& deviceMap,
41     const std::string& remoteName) {
42   // If the deviceMap is overridden, use that instead.
43   const auto errStr = c10::str(
44       "TensorPipe RPC backend only supports CPU tensors by default, please "
45       "move your tensors to CPU before sending them over RPC, or call "
46       "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
47       "configure device mapping. ",
48       "Request device mapping is not available for destination ",
49       remoteName);
50   std::vector<c10::Device> devices;
51   devices.reserve(tensors.size());
52   bool hasMappedDevice = false;
53   for (const auto& t : tensors) {
54     if (t.device().is_cpu()) {
55       const auto deviceIter = deviceMap.find(c10::kCPU);
56       if (deviceIter == deviceMap.end()) {
57         devices.emplace_back(c10::kCPU);
58       } else {
59         devices.emplace_back(deviceIter->second);
60         hasMappedDevice = true;
61       }
62     } else {
63       const auto deviceIter = deviceMap.find(t.device());
64       TORCH_CHECK(
65           deviceIter != deviceMap.end(),
66           errStr,
67           " for device ",
68           t.device(),
69           " but received a tensor on that device.");
70       devices.push_back(deviceIter->second);
71       hasMappedDevice = true;
72     }
73   }
74   if (!hasMappedDevice) {
75     devices.clear();
76   }
77   return devices;
78 }
79 
getStreamsFromPoolForDevices(const std::vector<c10::Device> & devices)80 std::vector<c10::Stream> getStreamsFromPoolForDevices(
81     const std::vector<c10::Device>& devices) {
82   if (devices.empty()) {
83     return {};
84   }
85   c10::impl::VirtualGuardImpl impl(devices[0].type());
86   std::vector<c10::Stream> streams;
87   streams.reserve(devices.size());
88   for (const c10::Device& device : devices) {
89     TORCH_INTERNAL_ASSERT(device.type() == impl.type());
90     streams.push_back(impl.getStreamFromGlobalPool(device));
91   }
92   return streams;
93 }
94 
getCurrentStreamsForDevices(const std::vector<c10::Device> & devices)95 std::vector<c10::Stream> getCurrentStreamsForDevices(
96     const std::vector<c10::Device>& devices) {
97   if (devices.empty()) {
98     return {};
99   }
100   c10::impl::VirtualGuardImpl impl(devices[0].type());
101   std::vector<c10::Stream> streams;
102   streams.reserve(devices.size());
103   for (const c10::Device& device : devices) {
104     TORCH_INTERNAL_ASSERT(device.type() == impl.type());
105     streams.push_back(impl.getStream(device));
106   }
107   return streams;
108 }
109 
getDevicesOfTensors(const std::vector<torch::Tensor> & tensors)110 std::vector<c10::Device> getDevicesOfTensors(
111     const std::vector<torch::Tensor>& tensors) {
112   std::optional<c10::impl::VirtualGuardImpl> impl;
113   size_t deviceCount = 0;
114   std::vector<bool> indexBitset;
115   for (const torch::Tensor& tensor : tensors) {
116     if (!tensor.is_cpu()) {
117       c10::Device device = tensor.device();
118       if (!impl.has_value()) {
119         impl.emplace(device.type());
120         indexBitset.resize(impl->deviceCount());
121       }
122       TORCH_INTERNAL_ASSERT(device.type() == impl->type());
123       TORCH_INTERNAL_ASSERT(device.has_index());
124       if (!indexBitset[device.index()]) {
125         deviceCount++;
126         indexBitset[device.index()] = true;
127       }
128     }
129   }
130   std::vector<c10::Device> devices;
131   devices.reserve(deviceCount);
132   for (const auto idx : c10::irange(indexBitset.size())) {
133     if (indexBitset[idx]) {
134       devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx));
135     }
136   }
137   return devices;
138 }
139 
makeStreamsWaitOnOthers(const std::vector<c10::Stream> & consumers,const std::vector<c10::Stream> & producers)140 void makeStreamsWaitOnOthers(
141     const std::vector<c10::Stream>& consumers,
142     const std::vector<c10::Stream>& producers) {
143   for (const c10::Stream& producer : producers) {
144     const c10::Stream& consumer =
145         getStreamForDevice(consumers, producer.device());
146     c10::Event event(producer.device_type());
147     event.record(producer);
148     event.block(consumer);
149   }
150 }
151 
152 } // namespace
153 
// Registries of named factory functions for tensorpipe transports and
// channels. Factories are added with C10_REGISTER_CREATOR (e.g. "uv",
// "basic"), and TensorPipeAgent::startImpl() iterates over Keys() and calls
// Create() on each to obtain the registration it hands to the context.
C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeTransportRegistry,
    TransportRegistration);

C10_DEFINE_REGISTRY_WITHOUT_WARNING(
    TensorPipeChannelRegistry,
    ChannelRegistration);
161 
guessAddress()162 const std::string& TensorPipeAgent::guessAddress() {
163   static const std::string uvAddress = []() {
164     char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
165     if (ifnameEnv != nullptr) {
166       auto [error, result] =
167           tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
168       if (error) {
169         LOG(WARNING) << "Failed to look up the IP address for interface "
170                      << ifnameEnv << " (" << error.what() << "), defaulting to "
171                      << kDefaultUvAddress;
172         return kDefaultUvAddress;
173       }
174       return result;
175     }
176     auto [error, result] = tensorpipe::transport::uv::lookupAddrForHostname();
177     if (error) {
178       LOG(WARNING) << "Failed to look up the IP address for the hostname ("
179                    << error.what() << "), defaulting to " << kDefaultUvAddress;
180       return kDefaultUvAddress;
181     }
182     return result;
183   }();
184   return uvAddress;
185 }
186 
187 namespace {
188 
makeUvTransport()189 std::unique_ptr<TransportRegistration> makeUvTransport() {
190   auto context = tensorpipe::transport::uv::create();
191   std::string address = TensorPipeAgent::guessAddress();
192   return std::make_unique<TransportRegistration>(TransportRegistration{
193       std::move(context), kUvTransportPriority, std::move(address)});
194 }
195 
196 // The UV transport is implemented using standard TCP connections. It leverages
197 // libuv (https://github.com/libuv/libuv) in order to be cross-platform.
198 C10_REGISTER_CREATOR(TensorPipeTransportRegistry, uv, makeUvTransport);
199 
200 #if TENSORPIPE_HAS_SHM_TRANSPORT
201 
makeShmTransport()202 std::unique_ptr<TransportRegistration> makeShmTransport() {
203   auto context = tensorpipe::transport::shm::create();
204   return std::make_unique<TransportRegistration>(
205       TransportRegistration{std::move(context), kShmTransportPriority, ""});
206 }
207 
208 // The SHM implements connections using ringbuffers residing in anonymous shared
209 // memory (plus UNIX domain sockets to bootstrap the connection and exchange
210 // file descriptors). It is Linux-only due to some advanced features (O_TMPFILE,
211 // eventfd, ...).
212 C10_REGISTER_CREATOR(TensorPipeTransportRegistry, shm, makeShmTransport);
213 
214 #endif // TENSORPIPE_HAS_SHM_TRANSPORT
215 
216 #if TENSORPIPE_HAS_IBV_TRANSPORT
217 
makeIbvTransport()218 std::unique_ptr<TransportRegistration> makeIbvTransport() {
219   auto context = tensorpipe::transport::ibv::create();
220   std::string address = TensorPipeAgent::guessAddress();
221   return std::make_unique<TransportRegistration>(TransportRegistration{
222       std::move(context), kIbvTransportPriority, std::move(address)});
223 }
224 
225 // The IBV transport sends data across using an InfiniBand queue pair, locally
226 // copying data to and from a staging buffer (registered with libibverbs) and
227 // issuing a RDMA write for transferring data across machines (plus a send for
228 // acknowledging it). It bootstraps using a standard TCP connection to exchange
229 // setup information. It is Linux-only.
230 C10_REGISTER_CREATOR(TensorPipeTransportRegistry, ibv, makeIbvTransport);
231 
232 #endif // TENSORPIPE_HAS_IBV_TRANSPORT
233 
makeBasicChannel()234 std::unique_ptr<ChannelRegistration> makeBasicChannel() {
235   auto context = tensorpipe::channel::basic::create();
236   return std::make_unique<ChannelRegistration>(
237       ChannelRegistration{std::move(context), kBasicChannelPriority});
238 }
239 
240 // The basic channel is just a straightforward adapter wrapper that allows any
241 // transport to be used as a channel.
242 C10_REGISTER_CREATOR(TensorPipeChannelRegistry, basic, makeBasicChannel);
243 
244 #if TENSORPIPE_HAS_CMA_CHANNEL
245 
makeCmaChannel()246 std::unique_ptr<ChannelRegistration> makeCmaChannel() {
247   auto context = tensorpipe::channel::cma::create();
248   return std::make_unique<ChannelRegistration>(
249       ChannelRegistration{std::move(context), kCmaChannelPriority});
250 }
251 
252 // The CMA channel uses the Linux cross-memory attach syscalls (process_vm_readv
253 // and _writev), which allow one process to access the private memory of another
254 // process (as long as they belong to the same user and other security
255 // constraints are satisfied). It does, more or less, what GDB does when it's
256 // attached to a running process.
257 C10_REGISTER_CREATOR(TensorPipeChannelRegistry, cma, makeCmaChannel);
258 
259 #endif // TENSORPIPE_HAS_CMA_CHANNEL
260 
261 constexpr static int kNumUvThreads = 16;
262 
makeMultiplexedUvChannel()263 std::unique_ptr<ChannelRegistration> makeMultiplexedUvChannel() {
264   std::vector<std::shared_ptr<tensorpipe::transport::Context>> contexts;
265   std::vector<std::shared_ptr<tensorpipe::transport::Listener>> listeners;
266   for (const auto laneIdx C10_UNUSED : c10::irange(kNumUvThreads)) {
267     auto context = tensorpipe::transport::uv::create();
268     std::string address = TensorPipeAgent::guessAddress();
269     contexts.push_back(std::move(context));
270     listeners.push_back(contexts.back()->listen(address));
271   }
272   auto context = tensorpipe::channel::mpt::create(
273       std::move(contexts), std::move(listeners));
274   return std::make_unique<ChannelRegistration>(
275       ChannelRegistration{std::move(context), kMultiplexedUvChannelPriority});
276 }
277 
278 // The multiplexed UV channel encapsulates multiple UV transports (each with its
279 // own event loop thread). Each channel will, in turn, contain multiple UV
280 // connections, one for each of those contexts. When sending a tensor, its data
281 // is split in equal chunks and each chunks is sent on a different connection
282 // and thus driven by a different thread. This is needed to reach very high
283 // bandwidths.
284 C10_REGISTER_CREATOR(
285     TensorPipeChannelRegistry,
286     mpt_uv,
287     makeMultiplexedUvChannel);
288 
289 } // namespace
290 
291 //////////////////////////  MetricsTracker  /////////////////////////////////
292 
TimeSeriesMetricsTracker(uint64_t currentSum,uint64_t currentCount)293 TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
294     uint64_t currentSum,
295     uint64_t currentCount)
296     : currentSum_(currentSum), currentCount_(currentCount) {}
297 
addData(uint64_t dataPoint)298 void TensorPipeAgent::TimeSeriesMetricsTracker::addData(uint64_t dataPoint) {
299   currentSum_ += dataPoint;
300   ++currentCount_;
301 }
302 
computeAverage() const303 float TensorPipeAgent::TimeSeriesMetricsTracker::computeAverage() const {
304   return currentCount_ == 0 ? 0 : currentSum_ / (float)currentCount_;
305 }
306 
307 ////////////////////////  TensorpipeRpcAgent  /////////////////////////////////
308 
removeFromTimeoutMap(uint64_t messageId)309 void TensorPipeAgent::removeFromTimeoutMap(uint64_t messageId) {
310   // Remove entry from timeoutMap_.
311   {
312     std::unique_lock<std::mutex> lock(timeoutMapMutex_);
313     auto it = messageIdToTimeout_.find(messageId);
314     if (it == messageIdToTimeout_.end()) {
315       // Already removed from the map by pollTimeoutRpcs(), no need to
316       // process further.
317       return;
318     }
319 
320     auto& expirationTime = it->second;
321 
322     auto& timedOutFuturesVector = timeoutMap_[expirationTime];
323     for (auto it = timedOutFuturesVector.begin();
324          it != timedOutFuturesVector.end();
325          it++) {
326       if (it->messageId == messageId) {
327         it = timedOutFuturesVector.erase(it);
328         break;
329       }
330     }
331 
332     if (timedOutFuturesVector.empty()) {
333       timeoutMap_.erase(expirationTime);
334     }
335 
336     // Remove from messageId to timeout map as well.
337     messageIdToTimeout_.erase(messageId);
338   }
339 }
340 
prepareNames(bool isStaticGroup)341 void TensorPipeAgent::prepareNames(bool isStaticGroup) {
342   std::unordered_map<std::string, worker_id_t> nameToId;
343   if (isStaticGroup) {
344     nameToId = collectNames(
345         rankToNameStore_, workerInfo_.id_, workerInfo_.name_, worldSize_);
346   } else {
347     nameToId = collectCurrentNames(
348         rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
349   }
350 
351   for (const auto& entry : nameToId) {
352     const auto& workerName = entry.first;
353     const auto& workerId = entry.second;
354     workerIdToInfo_.emplace(workerId, WorkerInfo(workerName, workerId));
355     workerNameToInfo_.emplace(workerName, WorkerInfo(workerName, workerId));
356   }
357 }
358 
checkAndSetStaticGroup(const c10::intrusive_ptr<::c10d::Store> & store)359 void TensorPipeAgent::checkAndSetStaticGroup(
360     const c10::intrusive_ptr<::c10d::Store>& store) {
361   std::string isStaticGroupKey("rpcIsStaticGroup");
362 
363   std::string isStaticGroupStr = isStaticGroup_ ? "true" : "false";
364   std::vector<uint8_t> isStaticGroupVec(
365       (uint8_t*)isStaticGroupStr.c_str(),
366       (uint8_t*)isStaticGroupStr.c_str() + isStaticGroupStr.length());
367   std::vector<uint8_t> returnedVec;
368   returnedVec = store->compareSet(
369       isStaticGroupKey, std::vector<uint8_t>(), isStaticGroupVec);
370   std::string returnedVal = std::string(returnedVec.begin(), returnedVec.end());
371   // In both cases, the returned value should be the value of isStaticGroupStr,
372   // otherwise there is a discrepency with initialization among one of the
373   // members
374   TORCH_CHECK(
375       returnedVal == isStaticGroupStr,
376       fmt::format(
377           "RPC group mixes statically and dynamically initialized members which is not supported. ",
378           "Static group property is initialized as {} and is trying to be set as {} ",
379           isStaticGroup_,
380           returnedVal));
381 }
382 
// Constructs the agent for worker `selfName` / `selfId`.
//
// Passing a `worldSize` marks the group as static; without one the group is
// dynamic. After member initialization the constructor:
//   1. checks the static/dynamic mode against the one recorded in `store`
//      (checkAndSetStaticGroup), and
//   2. gathers worker names/ids from the store (prepareNames).
// No listener is created here; that happens in startImpl().
TensorPipeAgent::TensorPipeAgent(
    const c10::intrusive_ptr<::c10d::Store>& store,
    std::string selfName,
    worker_id_t selfId,
    std::optional<int> worldSize,
    TensorPipeRpcBackendOptions opts,
    std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
    std::vector<c10::Device> devices,
    std::unique_ptr<RequestCallback> cb)
    : RpcAgent(
          WorkerInfo(std::move(selfName), selfId),
          std::move(cb),
          // Default RPC timeout: opts.rpcTimeoutSeconds converted into a
          // std::chrono::milliseconds. `opts` is still intact here because
          // the base class is initialized before any member (incl. opts_).
          std::chrono::milliseconds(
              (long)(opts.rpcTimeoutSeconds * kSecToMsConversion))),
      isStaticGroup_(worldSize.has_value()),
      store_(store),
      opts_(std::move(opts)),
      reverseDeviceMaps_(std::move(reverseDeviceMaps)),
      devices_(std::move(devices)),
      // Reads opts_, not the moved-from `opts`. Members are initialized in
      // declaration order, so this relies on opts_ being declared before
      // threadPool_ in the class definition (not visible here).
      threadPool_(opts_.numWorkerThreads),
      context_(std::make_shared<tensorpipe::Context>(
          tensorpipe::ContextOptions().name(workerInfo_.name_))),
      // Three views of the same `store`, distinguished by "names", "addrs"
      // and "shutdown" (presumably key prefixes -- confirm against the
      // member types in the header).
      rankToNameStore_("names", store),
      nameToAddressStore_("addrs", store),
      shutdownStore_("shutdown", store) {
  // worldSize_ is only assigned for static groups.
  if (isStaticGroup_) {
    worldSize_ = worldSize.value();
  }

  // check the static group attribute against store
  checkAndSetStaticGroup(store);

  // collect worker names
  prepareNames(isStaticGroup_);

  // Initialize the time-series metrics tracking map
  timeSeriesMetrics_.emplace(kGilAverageWaitTime, TimeSeriesMetricsTracker());
}
421 
~TensorPipeAgent()422 TensorPipeAgent::~TensorPipeAgent() {
423   VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is being destroyed";
424   shutdown();
425 }
426 
startImpl()427 void TensorPipeAgent::startImpl() {
428   VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is starting";
429 
430   std::vector<std::string> addresses;
431   int lowestPriority = std::numeric_limits<int>::max();
432   std::string lowestPriorityTransport;
433 
434   // Register transports
435   for (auto& key : TensorPipeTransportRegistry()->Keys()) {
436     int64_t priority = -1;
437     if (opts_.transports.has_value()) {
438       auto iter =
439           std::find(opts_.transports->begin(), opts_.transports->end(), key);
440       if (iter == opts_.transports->end()) {
441         continue;
442       }
443       // Assign priorities in reverse order of occurrence in the vector, so that
444       // a transport that comes before another receives a higher priority.
445       priority =
446           opts_.transports->size() - 1 - (iter - opts_.transports->begin());
447     }
448     std::unique_ptr<TransportRegistration> reg =
449         TensorPipeTransportRegistry()->Create(key);
450     if (!reg->transport->isViable()) {
451       continue;
452     }
453     if (priority == -1) {
454       priority = reg->priority;
455     }
456     if (priority < lowestPriority) {
457       lowestPriority = priority;
458       lowestPriorityTransport = key;
459     }
460     addresses.push_back(c10::str(key, "://", reg->address));
461     context_->registerTransport(priority, key, reg->transport);
462   }
463 
464   // Register channels
465   for (auto& key : TensorPipeChannelRegistry()->Keys()) {
466     int64_t priority = -1;
467     if (opts_.channels.has_value()) {
468       auto iter =
469           std::find(opts_.channels->begin(), opts_.channels->end(), key);
470       if (iter == opts_.channels->end()) {
471         continue;
472       }
473       // Assign priorities in reverse order of occurrence in the vector, so
474       // that a channel that comes before another receives a higher priority.
475       priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin());
476     }
477     std::unique_ptr<ChannelRegistration> reg =
478         TensorPipeChannelRegistry()->Create(key);
479     if (!reg->channel->isViable()) {
480       continue;
481     }
482     if (priority == -1) {
483       priority = reg->priority;
484     }
485     context_->registerChannel(priority, key, reg->channel);
486   }
487 
488   listener_ = context_->listen(addresses);
489 
490   // Store our own url.
491   const auto address = listener_->url(lowestPriorityTransport);
492   nameToAddressStore_.set(workerInfo_.name_, address);
493 
494   VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is using address "
495           << address;
496 
497   for (const auto& p : workerNameToInfo_) {
498     const auto& name = p.first;
499     auto nodeAddrData = nameToAddressStore_.get(name);
500     auto nodeAddrStr =
501         std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
502     workerNameToURL_.insert({name, nodeAddrStr});
503   }
504 
505   // Start the Timeout Thread
506   timeoutThread_ = std::thread(&TensorPipeAgent::pollTimeoutRpcs, this);
507 
508   listener_->accept([this](
509                         const tensorpipe::Error& error,
510                         std::shared_ptr<tensorpipe::Pipe> pipe) {
511     onListenerAccepted(error, pipe);
512   });
513 }
514 
onListenerAccepted(const tensorpipe::Error & error,std::shared_ptr<tensorpipe::Pipe> & pipe)515 void TensorPipeAgent::onListenerAccepted(
516     const tensorpipe::Error& error,
517     std::shared_ptr<tensorpipe::Pipe>& pipe) {
518   if (error) {
519     if (error.isOfType<tensorpipe::ListenerClosedError>() &&
520         !rpcAgentRunning_.load()) {
521       // This is expected.
522     } else {
523       LOG(WARNING) << "RPC agent for " << workerInfo_.name_
524                    << " encountered error when accepting incoming pipe: "
525                    << error.what();
526     }
527     return;
528   }
529 
530   // Accept the next connection request
531   listener_->accept([this](
532                         const tensorpipe::Error& error,
533                         std::shared_ptr<tensorpipe::Pipe> pipe) {
534     onListenerAccepted(error, pipe);
535   });
536 
537   VLOG(1) << "RPC agent for " << workerInfo_.name_
538           << " accepted incoming pipe from " << pipe->getRemoteName();
539 
540   // Arm for server read
541   respond(pipe);
542 }
543 
// Reads one RPC message from `pipe` and passes the outcome to `fn`.
//
// The read is asynchronous and happens in two steps. First the tensorpipe
// descriptor is read. Then the payload is read into buffers that
// tensorpipeAllocate() sets up on streams drawn from the pool for devices_.
// On success `fn` gets the deserialized message and those streams. On a
// tensorpipe error in either step, `fn` gets the error, a null message and no
// streams.
void TensorPipeAgent::pipeRead(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    std::function<void(
        const tensorpipe::Error&,
        c10::intrusive_ptr<Message>,
        std::vector<c10::Stream>)> fn) noexcept {
  // `pipe` is captured by value so the descriptor callback can issue the
  // payload read on it; `fn` is moved along into the next callback.
  pipe->readDescriptor([this, fn{std::move(fn)}, pipe](
                           const tensorpipe::Error& error,
                           tensorpipe::Descriptor tpDescriptor) mutable {
    if (error) {
      fn(error, c10::intrusive_ptr<Message>(), {});
      return;
    }

    // devices_ is read under the group-membership lock (the guard is given
    // isStaticGroup_; presumably it only locks for dynamic groups).
    std::vector<c10::Stream> streams;
    {
      GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
      streams = getStreamsFromPoolForDevices(devices_);
    }
    auto [tpAllocation, tpBuffers] = tensorpipeAllocate(tpDescriptor, streams);

    pipe->read(
        std::move(tpAllocation),
        // The descriptor, the buffers (behind a shared_ptr) and the streams
        // are moved into the completion callback, presumably so the memory
        // being read into outlives the asynchronous read.
        [tpDescriptor{std::move(tpDescriptor)},
         tpBuffers{
             std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
         fn{std::move(fn)},
         streams{std::move(streams)}](const tensorpipe::Error& error) mutable {
          if (error) {
            fn(error, c10::intrusive_ptr<Message>(), {});
            return;
          }

          // FIXME This does some unpickling, which could be a bit expensive:
          // perhaps it would be best to perform it inside the worker threads?
          c10::intrusive_ptr<Message> rpcMessage = tensorpipeDeserialize(
              std::move(tpDescriptor), std::move(*tpBuffers));

          fn(error, std::move(rpcMessage), std::move(streams));
        });
  });
}
586 
// Serializes `rpcMessage` with tensorpipeSerialize() (given the target
// `devices` and `streams`) and writes it to `pipe`. `fn` is called with the
// status of the write once tensorpipe completes it.
void TensorPipeAgent::pipeWrite(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    c10::intrusive_ptr<Message> rpcMessage,
    std::vector<c10::Device>&& devices,
    std::vector<c10::Stream> streams,
    std::function<void(const tensorpipe::Error&)> fn) noexcept {
  auto [tpMessage, tpBuffers] =
      tensorpipeSerialize(std::move(rpcMessage), std::move(devices), streams);

  pipe->write(
      std::move(tpMessage),
      // tpBuffers and streams are never used in the body. They are captured
      // only so their lifetime is tied to this callback, presumably keeping
      // the serialized data (and its streams) alive until the asynchronous
      // write finishes. Do not drop these captures.
      [tpBuffers{
           std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
       fn{std::move(fn)},
       streams{std::move(streams)}](const tensorpipe::Error& error) {
        fn(error);
      });
}
605 
sendCompletedResponseMessage(std::shared_ptr<tensorpipe::Pipe> & pipe,JitFuture & futureResponseMessage,uint64_t messageId,std::vector<c10::Stream> streams)606 void TensorPipeAgent::sendCompletedResponseMessage(
607     std::shared_ptr<tensorpipe::Pipe>& pipe,
608     JitFuture& futureResponseMessage,
609     uint64_t messageId,
610     std::vector<c10::Stream> streams) {
611   if (!rpcAgentRunning_.load()) {
612     LOG(WARNING) << "RPC agent for " << workerInfo_.name_
613                  << " won't send response to request #" << messageId << " to "
614                  << pipe->getRemoteName() << ", as the agent is shutting down";
615     return;
616   }
617 
618   VLOG(1) << "RPC agent for " << workerInfo_.name_
619           << " is sending response to request #" << messageId << " to "
620           << pipe->getRemoteName();
621 
622   if (!futureResponseMessage.hasError()) {
623     c10::intrusive_ptr<Message> responseMessage =
624         futureResponseMessage.value().toCustomClass<Message>();
625     responseMessage->setId(messageId);
626 
627     std::vector<c10::Device> devices;
628     try {
629       devices = getDevicesForRemote(pipe->getRemoteName(), *responseMessage);
630     } catch (const std::exception& e) {
631       responseMessage = createExceptionResponse(e.what(), messageId);
632     }
633 
634     for (const auto& tensor : responseMessage->tensors()) {
635       const auto device = tensor.device();
636       if (!device.is_cpu()) {
637         GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
638         if (std::find(devices_.begin(), devices_.end(), device) ==
639             devices_.end()) {
640           std::ostringstream oss;
641           std::copy(
642               devices_.begin(),
643               devices_.end(),
644               std::ostream_iterator<c10::Device>(oss, ", "));
645           responseMessage = createExceptionResponse(
646               c10::str(
647                   "RPC detected that a user-function output tensor on device ",
648                   device,
649                   ". This device is not one of the input tensor devices: ",
650                   oss.str(),
651                   "which is not yet supported. Please file a feature request "
652                   "issue in PyTorch GitHub repo."),
653               messageId);
654           break;
655         }
656       }
657     }
658 
659     pipeWrite(
660         pipe,
661         std::move(responseMessage),
662         std::move(devices),
663         std::move(streams),
664         [this, pipe, messageId](const tensorpipe::Error& error) {
665           if (error) {
666             LOG(WARNING)
667                 << "RPC agent for " << workerInfo_.name_
668                 << " encountered error when sending response to request #"
669                 << messageId << " to " << pipe->getRemoteName() << ": "
670                 << error.what();
671             return;
672           }
673 
674           VLOG(1) << "RPC agent for " << workerInfo_.name_
675                   << " done sending response to request #" << messageId
676                   << " to " << pipe->getRemoteName();
677         });
678   } else {
679     pipeWrite(
680         pipe,
681         createExceptionResponse(
682             futureResponseMessage.tryRetrieveErrorMessage(), messageId),
683         /* devices */ {},
684         std::move(streams),
685         [this, pipe, messageId](const tensorpipe::Error& error) {
686           if (error) {
687             LOG(WARNING)
688                 << "RPC agent for " << workerInfo_.name_
689                 << " encountered error when sending response to request #"
690                 << messageId << " to " << pipe->getRemoteName() << ": "
691                 << error.what();
692             return;
693           }
694 
695           VLOG(1) << "RPC agent for " << workerInfo_.name_
696                   << " done sending response to request #" << messageId
697                   << " to " << pipe->getRemoteName();
698         });
699   }
700 }
701 
// Arms an asynchronous read of the next request on `pipe` and serves it.
//
// When a request arrives, the next read is armed immediately (so requests on
// this pipe keep being read). The request is then handed to the RPC request
// callback (cb_) on threadPool_. When the returned future completes, with a
// value or an error, the response goes out via sendCompletedResponseMessage().
// A read error ends the read loop for this pipe; it is logged unless
// shuttingDown_ is set.
//
// Call-count metrics:
//   - serverActiveCalls_ spans from receipt of the request until the response
//     future completes;
//   - serverActiveAsyncCalls_ spans from cb_ returning until that future
//     completes.
void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
  pipeRead(
      pipe,
      [this, pipe](
          const tensorpipe::Error& error,
          c10::intrusive_ptr<Message> requestMessage,
          std::vector<c10::Stream> streams) mutable {
        if (error) {
          if (shuttingDown_) {
            // This is expected.
          } else {
            LOG(WARNING)
                << "RPC agent for " << workerInfo_.name_
                << " encountered error when reading incoming request from "
                << pipe->getRemoteName() << ": " << error.what();
          }
          return;
        }

        // Arm for next read
        respond(pipe);

        uint64_t messageId = requestMessage->id();
        increaseCallCount(serverActiveCalls_);

        VLOG(1) << "RPC agent for " << workerInfo_.name_
                << " received request #" << messageId << " from "
                << pipe->getRemoteName();

        // Defer user RPC UDF run to thread pool
        threadPool_.run([this,
                         pipe,
                         messageId,
                         requestMessage{std::move(requestMessage)},
                         streams{std::move(streams)}]() mutable {
          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " is running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";

          c10::intrusive_ptr<JitFuture> futureResponseMessage;
          try {
            // Instead of creating a MultiStreamGuard here, the ctx is passed
            // to the callback and the MultiStreamGuard is created there,
            // because subsequent processing can switch threads due to 1)
            // waiting for RRef arguments to become ready 2) async_execution.
            // Besides, the `ctx` also needs to be propagated to
            // `process***Call` methods to synchronize CUDA streams there
            // to make sure that we fetch the correct value from `to_here()`
            // call.
            futureResponseMessage =
                cb_->operator()(*requestMessage, std::move(streams));
          } catch (const std::exception& /* unused */) {
            // A synchronous throw from cb_ is turned into an errored future,
            // so the caller still receives an (exception) response.
            futureResponseMessage =
                c10::make_intrusive<JitFuture>(at::AnyClassType::get());
            futureResponseMessage->setError(std::current_exception());
          }

          increaseCallCount(serverActiveAsyncCalls_);
          futureResponseMessage->addCallback(
              [this, pipe, messageId](
                  JitFuture& futureResponseMessage) mutable {
                decreaseCallCount(serverActiveCalls_);
                decreaseCallCount(serverActiveAsyncCalls_);
                // Serialize the response on the streams currently set for the
                // devices the future's value lives on.
                auto streams = getCurrentStreamsForDevices(
                    futureResponseMessage.devices());
                sendCompletedResponseMessage(
                    pipe, futureResponseMessage, messageId, std::move(streams));
              });

          VLOG(1) << "RPC agent for " << workerInfo_.name_
                  << " done running request #" << messageId << " from "
                  << pipe->getRemoteName() << " in thread pool";
        });
      });
}
777 
send(const WorkerInfo & toWorkerInfo,c10::intrusive_ptr<Message> requestMessage,const float rpcTimeoutSeconds,const DeviceMap & deviceMap)778 c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
779     const WorkerInfo& toWorkerInfo,
780     c10::intrusive_ptr<Message> requestMessage,
781     const float rpcTimeoutSeconds,
782     const DeviceMap& deviceMap) {
783   TORCH_CHECK(
784       requestMessage->isRequest(),
785       "TensorPipeAgent::send(..) is only for sending requests.");
786 
787   if (!rpcAgentRunning_.load()) {
788     auto err = c10::str(
789         "Node ",
790         RpcAgent::getWorkerInfo().id_,
791         "tried to send() a message of type ",
792         requestMessage->type(),
793         " but RPC is no longer running on this node.");
794     TORCH_CHECK(false, err);
795   }
796 
797   const auto& url = findWorkerURL(toWorkerInfo);
798 
799   decltype(connectedPipes_)::iterator it;
800   {
801     std::unique_lock<std::mutex> lock(connectedPipesMutex_);
802 
803     // See if we already have a connection to this address or not
804     it = connectedPipes_.find(toWorkerInfo.id_);
805     if (it == connectedPipes_.end()) {
806       // An instance of ClientPipe cannot be copied or moved as it contains a
807       // mutex, and to force in-place construction in GCC 5 we need piecewise
808       // construction in order to work around an issue.
809       it = connectedPipes_
810                .emplace(
811                    std::piecewise_construct,
812                    std::forward_as_tuple(toWorkerInfo.id_),
813                    std::forward_as_tuple(context_->connect(
814                        url,
815                        tensorpipe::PipeOptions().remoteName(
816                            toWorkerInfo.name_))))
817                .first;
818     }
819   }
820   ClientPipe& clientPipe = it->second;
821 
822   std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture>
823       futureResponseMessage;
824   {
825     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
826     futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
827   }
828   uint64_t messageId = nextMessageID_++;
829   requestMessage->setId(messageId);
830 
831   {
832     std::unique_lock<std::mutex> lock(clientPipe.mutex_);
833     clientPipe.pendingResponseMessage_[messageId] = futureResponseMessage;
834   }
835 
836   // Get devices for tensors in the request message. This can throw if device
837   // maps are not configured properly for this request.
838   std::vector<c10::Device> devices;
839   if (deviceMap.empty()) {
840     devices =
841         getDevicesForRemote(clientPipe.pipe_->getRemoteName(), *requestMessage);
842   } else {
843     // If deviceMap is specified, use that instead.
844     devices = getDevicesForTensors(
845         requestMessage->tensors(),
846         deviceMap,
847         clientPipe.pipe_->getRemoteName());
848   }
849 
850   futureResponseMessage->jitFuture->addCallback(
851       [this](JitFuture& /* unused */) {
852         TORCH_INTERNAL_ASSERT(
853             this->threadPool_.inThreadPool(),
854             "Future marked complete from outside the thread pool");
855       });
856 
857   increaseCallCount(clientActiveCalls_);
858   // Use the default RPC timeout if no timeout is specified for this send call
859   auto timeout = rpcTimeoutSeconds == kUnsetRpcTimeout
860       ? getRpcTimeout()
861       : std::chrono::milliseconds(
862             static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
863 
864   // We only add to the timeoutMap_ if the timeout is not 0. Per our
865   // documentation, a user-provided timeout of 0 indicates the RPC should never
866   // expire (infinite timeout), so there is no need to track it in the
867   // timeoutMap_.
868   steady_clock_time_point expirationTime;
869   if (timeout.count() != 0) {
870     // Compute the expiration time for this message based on the timeout
871     expirationTime = computeRpcMessageExpiryTime(timeout);
872 
873     // Add the Future to the right vector in the timeoutMap_
874     {
875       std::unique_lock<std::mutex> lock(timeoutMapMutex_);
876       auto& timeoutFuturesVector = timeoutMap_[expirationTime];
877       messageIdToTimeout_.emplace(messageId, expirationTime);
878       timeoutFuturesVector.emplace_back(
879           messageId, futureResponseMessage, timeout);
880     }
881     timeoutThreadCV_.notify_one();
882   }
883 
884   VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
885           << messageId << " to " << clientPipe.pipe_->getRemoteName();
886 
887   std::vector<c10::Stream> streams;
888   {
889     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
890     streams = getStreamsFromPoolForDevices(devices_);
891   }
892   makeStreamsWaitOnOthers(
893       streams,
894       getCurrentStreamsForDevices(
895           getDevicesOfTensors(requestMessage->tensors())));
896   pipeWrite(
897       clientPipe.pipe_,
898       std::move(requestMessage),
899       std::move(devices),
900       std::move(streams),
901       [this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
902         if (error) {
903           if (error.isOfType<tensorpipe::PipeClosedError>() &&
904               !rpcAgentRunning_.load()) {
905             // This is expected.
906           } else {
907             LOG(WARNING) << "RPC agent for " << workerInfo_.name_
908                          << " encountered error when sending outgoing request #"
909                          << messageId << " to "
910                          << clientPipe.pipe_->getRemoteName() << ": "
911                          << error.what();
912           }
913           handleClientError(clientPipe, error);
914           return;
915         }
916 
917         VLOG(1) << "RPC agent for " << workerInfo_.name_ << " sent request #"
918                 << messageId << " to " << clientPipe.pipe_->getRemoteName();
919 
920         pipeRead(
921             clientPipe.pipe_,
922             [this, &clientPipe](
923                 const tensorpipe::Error& error,
924                 c10::intrusive_ptr<Message> responseMessage,
925                 std::vector<c10::Stream> streams) {
926               if (error) {
927                 if (error.isOfType<tensorpipe::PipeClosedError>() &&
928                     !rpcAgentRunning_.load()) {
929                   // This is expected.
930                 } else {
931                   LOG(WARNING)
932                       << "RPC agent for " << workerInfo_.name_
933                       << " encountered error when reading incoming response from "
934                       << clientPipe.pipe_->getRemoteName() << ": "
935                       << error.what();
936                 }
937                 handleClientError(clientPipe, error);
938                 return;
939               }
940 
941               // Identify future response message by message ID
942               uint64_t messageId = responseMessage->id();
943 
944               VLOG(1) << "RPC agent for " << workerInfo_.name_
945                       << " received response #" << messageId << " from "
946                       << clientPipe.pipe_->getRemoteName();
947 
948               std::shared_ptr<AtomicJitFuture> futureResponseMessage;
949               {
950                 std::lock_guard<std::mutex> lock(clientPipe.mutex_);
951                 // A read error will lead all following callbacks to be
952                 // invoked with error, and shouldn't reach here.
953                 TORCH_INTERNAL_ASSERT(
954                     !clientPipe.inError_, "Shouldn't be in error state");
955                 auto it = clientPipe.pendingResponseMessage_.find(messageId);
956                 TORCH_INTERNAL_ASSERT(
957                     it != clientPipe.pendingResponseMessage_.end(),
958                     "message ID ",
959                     messageId,
960                     " is not recognized");
961                 futureResponseMessage = std::move(it->second);
962                 clientPipe.pendingResponseMessage_.erase(it);
963               }
964 
965               // Remove entry from timeoutMap_.
966               removeFromTimeoutMap(messageId);
967 
968               if (responseMessage->type() == MessageType::EXCEPTION) {
969                 markFutureWithError(
970                     std::move(futureResponseMessage),
971                     std::string(
972                         responseMessage->payload().begin(),
973                         responseMessage->payload().end()));
974               } else {
975                 markFutureAsComplete(
976                     std::move(futureResponseMessage),
977                     std::move(responseMessage),
978                     std::move(streams));
979               }
980             });
981       });
982 
983   return futureResponseMessage->jitFuture;
984 }
985 
handleClientError(ClientPipe & clientPipe,const tensorpipe::Error & error)986 void TensorPipeAgent::handleClientError(
987     ClientPipe& clientPipe,
988     const tensorpipe::Error& error) {
989   // When an error occurs on a pipe all pending operations will be aborted and
990   // all callbacks invoked with error, hence we immediately flush all future
991   // messages belonging to this pipe.
992   decltype(clientPipe.pendingResponseMessage_) pendingMsgs;
993   {
994     std::lock_guard<std::mutex> lock(clientPipe.mutex_);
995     std::swap(clientPipe.pendingResponseMessage_, pendingMsgs);
996     clientPipe.inError_ = true;
997   }
998   std::string errorMsg = error.what();
999   for (auto& p : pendingMsgs) {
1000     markFutureWithError(std::move(p.second), errorMsg);
1001 
1002     // Remove entry from timeoutMap_.
1003     removeFromTimeoutMap(p.first);
1004   }
1005 }
1006 
// Body of the timeout thread. Repeatedly sleeps until the earliest deadline in
// timeoutMap_ has passed, then fails every future registered under that
// deadline with an RPC TIMEOUT error. Returns as soon as it observes
// rpcAgentRunning_ == false.
//
// Locking: timeoutMapMutex_ is held while waiting on timeoutThreadCV_ and
// while removing the expired bucket from timeoutMap_ and messageIdToTimeout_.
// It is released before any future is marked with an error.
void TensorPipeAgent::pollTimeoutRpcs() {
  while (rpcAgentRunning_.load()) {
    std::unique_lock<std::mutex> lock(timeoutMapMutex_);

    // We sleep until the earliest expiring RPC in the timeoutMap_. We must
    // also ensure that we sleep while the map is empty, and we exit sleeping
    // if the RPC Agent has been shutdown.
    for (;;) {
      // Re-checked on every wake-up (including the notify_one issued by
      // shutdownImpl) so the thread can exit promptly.
      if (!rpcAgentRunning_.load()) {
        return;
      }

      if (!timeoutMap_.empty()) {
        // NOTE(review): begin() is treated as the earliest deadline, which
        // assumes timeoutMap_ is ordered by expiry time — confirm in header.
        steady_clock_time_point earliestTimeout = timeoutMap_.begin()->first;
        if (std::chrono::steady_clock::now() >= earliestTimeout) {
          break;
        }
        // Wakes at the deadline, or earlier when send() registers a new
        // entry and notifies (the new deadline may be sooner).
        timeoutThreadCV_.wait_until(lock, earliestTimeout);
      } else {
        // Nothing tracked: sleep until send() adds an entry or shutdown.
        timeoutThreadCV_.wait(lock);
      }
    }

    // Move all these futures to a separate vector so we can process them
    // outside the lock.
    std::vector<TimeoutMessageMetadata> timedOutFutures =
        std::move(timeoutMap_.begin()->second);

    // We can safely remove this key from the timeoutMap_ since all these
    // futures will be processed.
    timeoutMap_.erase(timeoutMap_.begin());

    for (auto& timeoutMetadata : timedOutFutures) {
      // Remove from messageIdToTimeout map.
      messageIdToTimeout_.erase(timeoutMetadata.messageId);
    }
    lock.unlock();

    // Set an error on futures added to the timedOutFutures vector. We do this
    // outside the lock to prevent potential lock-order-inversions by callbacks
    // triggered by the setError call.
    for (auto& timeoutMetadata : timedOutFutures) {
      std::string errorMsg =
          fmt::format(kRpcTimeoutErrorStr, timeoutMetadata.timeout.count());
      auto err = makeRPCError(errorMsg, RPCErrorType::TIMEOUT);
      markFutureWithError(
          std::move(timeoutMetadata.responseFuture), std::move(err));
    }
  }
}
1057 
leaveGroup()1058 void TensorPipeAgent::leaveGroup() {
1059   std::unique_lock<std::mutex> lock(callCountMutex_);
1060   // local worker ActiveCallCount is 0 at this point and we will shutdown
1061   // (any future calls will be dropped)
1062   callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });
1063 
1064   // Remove this agent's WorkerInfo from store
1065   removeCurrentName(rankToNameStore_, workerInfo_.id_, workerInfo_.name_);
1066 
1067   // Set internal variable to be used during destructor
1068   shuttingDown_ = true;
1069 }
1070 
// TODO: Remove join()
// Blocks until no worker in the group has active client calls.
//
// Dynamic groups (!isStaticGroup_) simply delegate to leaveGroup(). Static
// groups run rounds until the group-wide count is zero. Each round:
//  1. Waits locally until clientActiveCalls_ reaches 0.
//  2. Rendezvous with the other worldSize_ workers via
//     syncCallCount(shutdownStore_, worldSize_).
//  3. Exchanges active call counts. Per the comments below this acts like an
//     allreduce; the result is treated as the total across all workers.
// If the total is 0 the loop exits. When `shutdown` is true it first sets
// shuttingDown_ and performs one more sync. Otherwise another round runs.
//
// @param shutdown whether to mark the agent as shutting down once drained.
// The float parameter is unused.
void TensorPipeAgent::join(bool shutdown, float /* unused */) {
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is joining";
  if (!isStaticGroup_) {
    leaveGroup();
    return;
  }

  // This method behaves like a barrier, as it can only return once all workers
  // have no more requests pending, including "nested" requests (triggered from
  // within the remote code of another call) and "follow-up" requests (triggered
  // from the callback of a future).
  while (true) {
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // It is enough to wait for there to be no more active client calls, since
      // each server call corresponds to a client call for some other worker.
      callCountCV_.wait(lock, [this] { return clientActiveCalls_ == 0; });

      // We'd like to immediately proceed with the allreduce, but it's a call
      // that may block for some time, as it waits for other workers to also
      // complete all their active client calls. While we call allreduce we must
      // hold the mutex, or else the count we send to other workers may get
      // stale (e.g., if some nested call happens in the meantime). But we can't
      // hold the lock for an indeterminately long time, as that would block
      // other operations (e.g., send). Thus we must release the lock and only
      // re-acquire it when all workers are ready to proceed with the allreduce.
      // We perform this synchronization using a barrier.
    }
    VLOG(1) << "RPC agent for " << workerInfo_.name_
            << " completed all client calls and is entering a barrier";
    syncCallCount(shutdownStore_, worldSize_);
    {
      std::unique_lock<std::mutex> lock(callCountMutex_);
      // At this point, the count may have become non-zero again. We can't wait
      // for those calls to complete as other workers are waiting for us in the
      // allreduce and we would block them. Thus we send our count even if it is
      // non-zero and if anyone (be it us or another worker) has a non-zero
      // count we'll just do another round.
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " exited the barrier and found " << clientActiveCalls_
              << " active client calls";
      int totalClientActiveCalls =
          syncCallCount(shutdownStore_, worldSize_, clientActiveCalls_);
      VLOG(1) << "RPC agent for " << workerInfo_.name_
              << " completed sync call counts and got a total of "
              << totalClientActiveCalls
              << " active client calls across all workers";
      if (totalClientActiveCalls == 0) {
        if (shutdown) {
          // Mark shutdown and sync once more so every worker sets it before
          // any of them proceeds.
          shuttingDown_ = true;
          syncCallCount(shutdownStore_, worldSize_);
        }
        break;
      }
    }
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_ << " done joining";
}
1130 
// Tears down the agent's background machinery, strictly in this order:
//  1. Wakes and joins the timeout thread. NOTE(review): this presumably
//     relies on rpcAgentRunning_ already being false, so that
//     pollTimeoutRpcs returns — confirm at the call site.
//  2. Joins the TensorPipe context, which closes pipes and listeners and
//     fires pending callbacks with errors.
//  3. Waits for the thread pool to drain.
// The ordering matters; see the NOTE before waitWorkComplete().
void TensorPipeAgent::shutdownImpl() {
  // FIXME Isn't it too verbose for a library to print logs in normal operation?
  LOG(INFO) << "RPC agent for " << workerInfo_.name_ << " is shutting down";

  // Join the Timeout Thread
  timeoutThreadCV_.notify_one();
  if (timeoutThread_.joinable()) {
    timeoutThread_.join();
  }
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for timeout thread to join";

  // This will close all the pipes and listeners, invoke all callbacks with
  // errors, turn down the I/O event loops and wait for everything to terminate.
  context_->join();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for TensorPipe context to join";

  // NOTE: We need to call waitWorkComplete in the end after we have shutdown
  // all listeners for Tensorpipe. This is to drain any already accepted work
  // in the ThreadPool. If this is done before we shutdown the listeners,
  // additional work could be added after this call and before we shutdown
  // listeners. This work would continue executing in the threadpool and might
  // cause issues during shutdown of the system.
  threadPool_.waitWorkComplete();
  VLOG(1) << "RPC agent for " << workerInfo_.name_
          << " done waiting for thread pool to complete work";
}
1159 
getWorkerInfo(const std::string & workerName) const1160 const WorkerInfo& TensorPipeAgent::getWorkerInfo(
1161     const std::string& workerName) const {
1162   std::unordered_map<std::string, WorkerInfo>::const_iterator it;
1163   {
1164     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1165     it = workerNameToInfo_.find(workerName);
1166   }
1167   TORCH_CHECK(
1168       it != workerNameToInfo_.end(),
1169       fmt::format(
1170           "name:{},rank:{} could not find destination name {}",
1171           workerInfo_.name_,
1172           workerInfo_.id_,
1173           workerName));
1174   return it->second;
1175 }
1176 
getWorkerInfo(worker_id_t workerId) const1177 const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const {
1178   std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it;
1179   {
1180     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1181     it = workerIdToInfo_.find(workerId);
1182   }
1183   TORCH_CHECK(
1184       it != workerIdToInfo_.end(),
1185       fmt::format(
1186           "name:{},rank:{} could not find destination id {}",
1187           workerInfo_.name_,
1188           workerInfo_.id_,
1189           workerId));
1190   return it->second;
1191 }
1192 
getWorkerInfos() const1193 std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {
1194   std::vector<WorkerInfo> workerInfos;
1195   workerInfos.reserve(workerNameToInfo_.size());
1196   for (auto& item : workerNameToInfo_) {
1197     workerInfos.emplace_back(item.second);
1198   }
1199   return workerInfos;
1200 }
1201 
findWorkerURL(const WorkerInfo & worker) const1202 const std::string& TensorPipeAgent::findWorkerURL(
1203     const WorkerInfo& worker) const {
1204   std::unordered_map<std::string, std::string>::const_iterator it;
1205   {
1206     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1207     it = workerNameToURL_.find(worker.name_);
1208   }
1209   TORCH_CHECK(
1210       it != workerNameToURL_.end(),
1211       fmt::format(
1212           "name:{},rank:{} could not find destination url for name {}",
1213           workerInfo_.name_,
1214           workerInfo_.id_,
1215           worker.name_));
1216   return it->second;
1217 }
1218 
updateGroupMembership(const WorkerInfo & workerInfo,const std::vector<c10::Device> & devices,const std::unordered_map<std::string,DeviceMap> & reverseDeviceMaps,bool isJoin)1219 void TensorPipeAgent::updateGroupMembership(
1220     const WorkerInfo& workerInfo,
1221     const std::vector<c10::Device>& devices,
1222     const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
1223     bool isJoin) {
1224   std::string name = workerInfo.name_;
1225   worker_id_t id = workerInfo.id_;
1226   // Rank with workerInfo is joining the group, update internal mappings
1227   if (isJoin) {
1228     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1229     workerIdToInfo_.emplace(id, workerInfo);
1230     workerNameToInfo_.emplace(name, workerInfo);
1231 
1232     // TODO: we should get nodeAddrStr in the joining process, then pass in as
1233     // an argument rather than getting from store each time
1234     auto nodeAddrData = nameToAddressStore_.get(name);
1235     auto nodeAddrStr =
1236         std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
1237     workerNameToURL_.insert({name, nodeAddrStr});
1238 
1239     for (const auto& it : reverseDeviceMaps) {
1240       if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
1241         reverseDeviceMaps_[it.first] = it.second;
1242       }
1243     }
1244     // TODO: clean up mutex for devices_ usage
1245     // Add devices that have not been added yet
1246     for (const auto& it : devices) {
1247       if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
1248         devices_.push_back(it);
1249       }
1250     }
1251   } else {
1252     workerIdToInfo_.erase(id);
1253     workerNameToInfo_.erase(name);
1254     workerNameToURL_.erase(name);
1255 
1256     // remove reverse device maps that are no longer used
1257     for (auto it = reverseDeviceMaps_.begin();
1258          it != reverseDeviceMaps_.end();) {
1259       if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) {
1260         it = reverseDeviceMaps_.erase(it);
1261       } else {
1262         it++;
1263       }
1264     }
1265 
1266     // remove devices that are no longer used
1267     for (auto it = devices_.begin(); it != devices_.end();) {
1268       if (std::find(devices.begin(), devices.end(), *it) == devices.end()) {
1269         it = devices_.erase(it);
1270       } else {
1271         it++;
1272       }
1273     }
1274   }
1275 }
getMetrics()1276 std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() {
1277   std::unordered_map<std::string, std::string> metrics;
1278   metrics[kThreadPoolSize] = std::to_string(threadPool_.size());
1279   metrics[kNumIdleThreads] = std::to_string(threadPool_.numAvailable());
1280   {
1281     std::unique_lock<std::mutex> lock(callCountMutex_);
1282     metrics[kClientActiveCalls] = std::to_string(clientActiveCalls_);
1283     metrics[kServerActiveCalls] = std::to_string(serverActiveCalls_);
1284     metrics[kServerActiveAsyncCalls] = std::to_string(serverActiveAsyncCalls_);
1285   }
1286   if (isGILProfilingEnabled()) {
1287     {
1288       std::unique_lock<std::mutex> lock(metricsMutex_);
1289       // Include the averages for each time series metric. This is just the GIL
1290       // Wait Time for now.
1291       auto averageGilWaitTime =
1292           timeSeriesMetrics_[kGilAverageWaitTime].computeAverage();
1293       lock.unlock();
1294       metrics[kGilAverageWaitTime] = std::to_string(averageGilWaitTime);
1295     }
1296   }
1297 
1298   return metrics;
1299 }
1300 
addGilWaitTime(const std::chrono::microseconds gilWaitTime)1301 void TensorPipeAgent::addGilWaitTime(
1302     const std::chrono::microseconds gilWaitTime) {
1303   std::lock_guard<std::mutex> lock(metricsMutex_);
1304   timeSeriesMetrics_[kGilAverageWaitTime].addData(gilWaitTime.count());
1305 }
1306 
getNetworkData()1307 TensorPipeAgent::NetworkDataDict TensorPipeAgent::getNetworkData() {
1308   std::lock_guard<std::mutex> lock(networkDataMutex_);
1309   return networkData_;
1310 }
1311 
getNetworkSourceInfo()1312 NetworkSourceInfo TensorPipeAgent::getNetworkSourceInfo() {
1313   NetworkSourceInfo info = {
1314       RpcAgent::getWorkerInfo().id_,
1315       nameToAddressStore_.get(RpcAgent::getWorkerInfo().name_)};
1316 
1317   return info;
1318 }
1319 
trackNetworkData(uint64_t requestSize,uint64_t responseSize,const std::string & destWorkerName)1320 void TensorPipeAgent::trackNetworkData(
1321     uint64_t requestSize,
1322     uint64_t responseSize,
1323     const std::string& destWorkerName) {
1324   std::lock_guard<std::mutex> lock(networkDataMutex_);
1325   networkData_[destWorkerName].numCalls++;
1326   networkData_[destWorkerName].totalSentBytes += requestSize;
1327   networkData_[destWorkerName].totalRecvBytes += responseSize;
1328 }
1329 
trackNetworkError(uint64_t requestSize,const std::string & destWorkerName)1330 void TensorPipeAgent::trackNetworkError(
1331     uint64_t requestSize,
1332     const std::string& destWorkerName) {
1333   std::lock_guard<std::mutex> lock(networkDataMutex_);
1334   networkData_[destWorkerName].numCalls++;
1335   networkData_[destWorkerName].totalSentBytes += requestSize;
1336   networkData_[destWorkerName].totalErrors++;
1337 }
1338 
increaseCallCount(int32_t & count)1339 void TensorPipeAgent::increaseCallCount(int32_t& count) {
1340   {
1341     std::unique_lock<std::mutex> lock(callCountMutex_);
1342     ++count;
1343   }
1344   callCountCV_.notify_all();
1345 }
1346 
decreaseCallCount(int32_t & count)1347 void TensorPipeAgent::decreaseCallCount(int32_t& count) {
1348   {
1349     std::unique_lock<std::mutex> lock(callCountMutex_);
1350     --count;
1351   }
1352   callCountCV_.notify_all();
1353 }
1354 
markFutureAsComplete(std::shared_ptr<AtomicJitFuture> atomicFuture,c10::intrusive_ptr<Message> message,std::vector<c10::Stream> streams)1355 void TensorPipeAgent::markFutureAsComplete(
1356     std::shared_ptr<AtomicJitFuture> atomicFuture,
1357     c10::intrusive_ptr<Message> message,
1358     std::vector<c10::Stream> streams) {
1359   if (!atomicFuture->isComplete.test_and_set()) {
1360     // Completing the future will run its callbacks, which could execute
1361     // arbitrary user code. To prevent blocking or stalling the TensorPipe event
1362     // loops, we defer this to a worker thread.
1363     threadPool_.run([this,
1364                      atomicFuture{std::move(atomicFuture)},
1365                      message{std::move(message)},
1366                      streams{std::move(streams)}]() mutable {
1367       c10::MultiStreamGuard guard(streams);
1368       std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
1369           message->getStorages();
1370       atomicFuture->jitFuture->markCompleted(
1371           std::move(message), std::move(storages));
1372       // The future's callbacks may schedule further RPCs, increasing the count.
1373       // Thus we must decrease it after completing the future, otherwise it may
1374       // briefly dip to zero and trick join into thinking all work is done.
1375       decreaseCallCount(clientActiveCalls_);
1376     });
1377   }
1378 }
1379 
markFutureWithError(std::shared_ptr<AtomicJitFuture> atomicFuture,std::string errorMsg)1380 void TensorPipeAgent::markFutureWithError(
1381     std::shared_ptr<AtomicJitFuture> atomicFuture,
1382     std::string errorMsg) {
1383   if (!atomicFuture->isComplete.test_and_set()) {
1384     // Completing the future will run its callbacks, which could execute
1385     // arbitrary user code. To prevent blocking or stalling the TensorPipe event
1386     // loops, we defer this to a worker thread.
1387     threadPool_.run([this,
1388                      atomicFuture{std::move(atomicFuture)},
1389                      errorMsg{std::move(errorMsg)}]() mutable {
1390       atomicFuture->jitFuture->setError(
1391           std::make_exception_ptr(std::runtime_error(errorMsg)));
1392       // The future's callbacks may schedule further RPCs, increasing the count.
1393       // Thus we must decrease it after completing the future, otherwise it may
1394       // briefly dip to zero and trick join into thinking all work is done.
1395       decreaseCallCount(clientActiveCalls_);
1396     });
1397   }
1398 }
1399 
getDevicesForRemote(const std::string & remoteName,const Message & message) const1400 std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
1401     const std::string& remoteName,
1402     const Message& message) const {
1403   std::unordered_map<std::string, DeviceMap> deviceMaps;
1404   {
1405     GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1406     deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
1407   }
1408 
1409   const auto errStr = c10::str(
1410       "TensorPipe RPC backend only supports CPU tensors by default, please "
1411       "move your tensors to CPU before sending them over RPC, or call "
1412       "`set_device_map` on `TensorPipeRpcBackendOptions` to explicitly "
1413       "configure device mapping. ",
1414       message.isRequest() ? "Request" : "Response",
1415       " device mapping is not available for destination ",
1416       remoteName);
1417 
1418   const auto& iter = deviceMaps.find(remoteName);
1419   if (iter == deviceMaps.end()) {
1420     for (const auto& t : message.tensors()) {
1421       TORCH_CHECK(
1422           t.device().is_cpu(),
1423           errStr,
1424           ", but found tensor on device: ",
1425           t.device());
1426     }
1427     return {};
1428   } else {
1429     return getDevicesForTensors(message.tensors(), iter->second, errStr);
1430   }
1431 }
1432 
getDeviceMap(const WorkerInfo & dst) const1433 DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
1434   auto it = opts_.deviceMaps.find(dst.name_);
1435   if (it == opts_.deviceMaps.end()) {
1436     return {};
1437   }
1438   return it->second;
1439 }
1440 
getStore() const1441 const c10::intrusive_ptr<::c10d::Store> TensorPipeAgent::getStore() const {
1442   return store_;
1443 }
1444 
getBackendOptions() const1445 TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
1446   return opts_;
1447 }
1448 
getDevices() const1449 const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
1450   GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
1451   return devices_;
1452 }
1453 
timeoutMapSize()1454 size_t TensorPipeAgent::timeoutMapSize() {
1455   std::unique_lock<std::mutex> lock(timeoutMapMutex_);
1456   return timeoutMap_.size();
1457 }
1458 
numPendingResponses()1459 size_t TensorPipeAgent::numPendingResponses() {
1460   std::unique_lock<std::mutex> lock(callCountMutex_);
1461   return clientActiveCalls_;
1462 }
1463 
messageIdToTimeoutMapSize()1464 size_t TensorPipeAgent::messageIdToTimeoutMapSize() {
1465   std::unique_lock<std::mutex> lock(timeoutMapMutex_);
1466   return messageIdToTimeout_.size();
1467 }
1468 
1469 } // namespace torch::distributed::rpc
1470 
1471 #endif // USE_TENSORPIPE
1472