xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef USE_TENSORPIPE
2 
3 #include <torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h>
4 #include <torch/csrc/distributed/rpc/utils.h>
5 
6 namespace torch {
7 namespace distributed {
8 namespace rpc {
9 
// Copy a raw byte buffer into a std::string holding exactly the same bytes
// (embedded '\0' characters included). Used to build a map key from a message
// payload.
static std::string fromVecToString(const std::vector<char>& vec) {
  std::string bytes;
  bytes.reserve(vec.size());
  bytes.append(vec.cbegin(), vec.cend());
  return bytes;
}
13 
FaultyTensorPipeAgent(const c10::intrusive_ptr<::c10d::Store> & store,std::string selfName,worker_id_t selfId,int worldSize,FaultyTensorPipeRpcBackendOptions opts,std::unordered_map<std::string,DeviceMap> reverseDeviceMaps,std::vector<c10::Device> devices,std::unique_ptr<RequestCallback> callback)14 FaultyTensorPipeAgent::FaultyTensorPipeAgent(
15     const c10::intrusive_ptr<::c10d::Store>& store,
16     std::string selfName,
17     worker_id_t selfId,
18     int worldSize,
19     FaultyTensorPipeRpcBackendOptions opts,
20     std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
21     std::vector<c10::Device> devices,
22     std::unique_ptr<RequestCallback> callback)
23     : TensorPipeAgent(
24           store,
25           std::move(selfName),
26           selfId,
27           worldSize,
28           std::move(opts),
29           std::move(reverseDeviceMaps),
30           std::move(devices),
31           std::move(callback)),
32       numFailSends_(opts.numFailSends),
33       messageTypesToFail_(parseMessagesToFailInput(opts.messagesToFail)),
34       messageTypesToDelay_(parseMessagesToDelay(opts.messagesToDelay)) {}
35 
parseMessagesToFailInput(const std::vector<std::string> & messagesToFail) const36 std::vector<MessageType> FaultyTensorPipeAgent::parseMessagesToFailInput(
37     const std::vector<std::string>& messagesToFail) const {
38   // Since we can only pass strings corresponding to the Message Types from the
39   // python tests, we must parse the list of strings and resolve the actual
40   // types. We will then check this list of types in the send function to
41   // determine whether we should fail or not.
42   std::vector<MessageType> messageTypesToFail;
43   messageTypesToFail.reserve(messagesToFail.size());
44   for (const auto& msgString : messagesToFail) {
45     messageTypesToFail.push_back(messageStringToType(msgString));
46   }
47   return messageTypesToFail;
48 }
49 
50 std::unordered_map<MessageType, float, std::hash<int>> FaultyTensorPipeAgent::
parseMessagesToDelay(const std::unordered_map<std::string,float> & messageTypesToDelay) const51     parseMessagesToDelay(const std::unordered_map<std::string, float>&
52                              messageTypesToDelay) const {
53   std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
54   for (const auto& messagePair : messageTypesToDelay) {
55     float delay = messagePair.second;
56     TORCH_CHECK(
57         delay >= 0,
58         "Delays passed to FaultyTensorPipeAgent must be non-negative.")
59     delayMessages.insert({messageStringToType(messagePair.first), delay});
60   }
61   return delayMessages;
62 }
63 
send(const WorkerInfo & to,c10::intrusive_ptr<Message> message,const float rpcTimeoutSeconds,const DeviceMap &)64 c10::intrusive_ptr<JitFuture> FaultyTensorPipeAgent::send(
65     const WorkerInfo& to,
66     c10::intrusive_ptr<Message> message,
67     const float rpcTimeoutSeconds,
68     const DeviceMap& /* unused */) {
69   // We only fail control messages that have been specified by the test case.
70   // For all other messages, we just send them without any failures.
71   if (!shouldFailMessage(message->type())) {
72     return TensorPipeAgent::send(to, std::move(message), rpcTimeoutSeconds);
73   }
74 
75   // This send function checks the failMessageCountMap_ to check whether
76   // we must fail the next send. If the send must be failed, we set an error
77   // on the returned future immediately and increment the counter in the map,
78   // otherwise we just call the TensorPipeAgent send.
79   const auto key = fromVecToString(message->payload());
80   std::unique_lock<std::mutex> lock(failMapMutex_);
81   auto it = failMessageCountMap_.find(key);
82   if (it == failMessageCountMap_.end()) {
83     failMessageCountMap_[key] = 0;
84   }
85   if (failMessageCountMap_[key] < numFailSends_) {
86     failMessageCountMap_[key]++;
87     lock.unlock();
88     auto jitFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
89     jitFuture->setError(std::make_exception_ptr(std::runtime_error(makeRPCError(
90         c10::str("Send attempt failed intentionally for ", key),
91         RPCErrorType::INTENTIONAL_FAILURE))));
92     return jitFuture;
93   } else {
94     lock.unlock();
95     return TensorPipeAgent::send(to, std::move(message), rpcTimeoutSeconds);
96   }
97 }
98 
pipeWrite(const std::shared_ptr<tensorpipe::Pipe> & pipe,c10::intrusive_ptr<Message> rpcMessage,std::vector<c10::Device> && devices,std::vector<c10::Stream> streams,std::function<void (const tensorpipe::Error &)> fn)99 void FaultyTensorPipeAgent::pipeWrite(
100     const std::shared_ptr<tensorpipe::Pipe>& pipe,
101     c10::intrusive_ptr<Message> rpcMessage,
102     std::vector<c10::Device>&& devices,
103     std::vector<c10::Stream> streams,
104     std::function<void(const tensorpipe::Error&)> fn) noexcept {
105   float msgDelay = getDelayForMessage(rpcMessage->type());
106   if (msgDelay != 0) {
107     // Sleep for the specified delay for the message.
108     std::this_thread::sleep_for(std::chrono::milliseconds(
109         static_cast<int>(msgDelay * kSecToMsConversion)));
110   }
111   TensorPipeAgent::pipeWrite(pipe, rpcMessage, std::move(devices), streams, fn);
112 }
113 
shouldFailMessage(MessageType type) const114 bool FaultyTensorPipeAgent::shouldFailMessage(MessageType type) const {
115   // Return true if the input message type is in the messageTypesToFail_ list
116   return (
117       std::find(messageTypesToFail_.begin(), messageTypesToFail_.end(), type) !=
118       messageTypesToFail_.end());
119 }
120 
getDelayForMessage(MessageType type) const121 float FaultyTensorPipeAgent::getDelayForMessage(MessageType type) const {
122   const auto& it = messageTypesToDelay_.find(type);
123   return it == messageTypesToDelay_.end() ? 0 : it->second;
124 }
125 
messageStringToType(const std::string & messageString) const126 MessageType FaultyTensorPipeAgent::messageStringToType(
127     const std::string& messageString) const {
128   // Lazily constructed map that returns string to message type mapping
129   static std::unordered_map<std::string, MessageType> msgMap = {
130       {"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
131       {"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
132       {"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
133       {"CLEANUP_AUTOGRAD_CONTEXT_REQ",
134        MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
135       {"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
136       {"SCRIPT_REMOTE_CALL", MessageType::SCRIPT_REMOTE_CALL},
137       {"PYTHON_CALL", MessageType::PYTHON_CALL},
138       {"SCRIPT_CALL", MessageType::SCRIPT_CALL},
139       {"PYTHON_RREF_FETCH_CALL", MessageType::PYTHON_RREF_FETCH_CALL},
140       {"SCRIPT_RREF_FETCH_CALL", MessageType::SCRIPT_RREF_FETCH_CALL}};
141   const auto& it = msgMap.find(messageString);
142   TORCH_CHECK(
143       it != msgMap.end(),
144       "No mapping to rpc::MessageType exists for ",
145       messageString);
146   return it->second;
147 }
148 
149 } // namespace rpc
150 } // namespace distributed
151 } // namespace torch
152 
153 #endif // USE_TENSORPIPE
154