1 #ifdef USE_TENSORPIPE
2
3 #include <torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h>
4 #include <torch/csrc/distributed/rpc/utils.h>
5
6 namespace torch {
7 namespace distributed {
8 namespace rpc {
9
// Reinterprets a raw byte buffer as a std::string, byte for byte.
// Embedded '\0' characters are kept because the copy is length-driven,
// not NUL-terminated.
static std::string fromVecToString(const std::vector<char>& vec) {
  std::string bytes;
  bytes.assign(vec.cbegin(), vec.cend());
  return bytes;
}
13
FaultyTensorPipeAgent(const c10::intrusive_ptr<::c10d::Store> & store,std::string selfName,worker_id_t selfId,int worldSize,FaultyTensorPipeRpcBackendOptions opts,std::unordered_map<std::string,DeviceMap> reverseDeviceMaps,std::vector<c10::Device> devices,std::unique_ptr<RequestCallback> callback)14 FaultyTensorPipeAgent::FaultyTensorPipeAgent(
15 const c10::intrusive_ptr<::c10d::Store>& store,
16 std::string selfName,
17 worker_id_t selfId,
18 int worldSize,
19 FaultyTensorPipeRpcBackendOptions opts,
20 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
21 std::vector<c10::Device> devices,
22 std::unique_ptr<RequestCallback> callback)
23 : TensorPipeAgent(
24 store,
25 std::move(selfName),
26 selfId,
27 worldSize,
28 std::move(opts),
29 std::move(reverseDeviceMaps),
30 std::move(devices),
31 std::move(callback)),
32 numFailSends_(opts.numFailSends),
33 messageTypesToFail_(parseMessagesToFailInput(opts.messagesToFail)),
34 messageTypesToDelay_(parseMessagesToDelay(opts.messagesToDelay)) {}
35
parseMessagesToFailInput(const std::vector<std::string> & messagesToFail) const36 std::vector<MessageType> FaultyTensorPipeAgent::parseMessagesToFailInput(
37 const std::vector<std::string>& messagesToFail) const {
38 // Since we can only pass strings corresponding to the Message Types from the
39 // python tests, we must parse the list of strings and resolve the actual
40 // types. We will then check this list of types in the send function to
41 // determine whether we should fail or not.
42 std::vector<MessageType> messageTypesToFail;
43 messageTypesToFail.reserve(messagesToFail.size());
44 for (const auto& msgString : messagesToFail) {
45 messageTypesToFail.push_back(messageStringToType(msgString));
46 }
47 return messageTypesToFail;
48 }
49
50 std::unordered_map<MessageType, float, std::hash<int>> FaultyTensorPipeAgent::
parseMessagesToDelay(const std::unordered_map<std::string,float> & messageTypesToDelay) const51 parseMessagesToDelay(const std::unordered_map<std::string, float>&
52 messageTypesToDelay) const {
53 std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
54 for (const auto& messagePair : messageTypesToDelay) {
55 float delay = messagePair.second;
56 TORCH_CHECK(
57 delay >= 0,
58 "Delays passed to FaultyTensorPipeAgent must be non-negative.")
59 delayMessages.insert({messageStringToType(messagePair.first), delay});
60 }
61 return delayMessages;
62 }
63
send(const WorkerInfo & to,c10::intrusive_ptr<Message> message,const float rpcTimeoutSeconds,const DeviceMap &)64 c10::intrusive_ptr<JitFuture> FaultyTensorPipeAgent::send(
65 const WorkerInfo& to,
66 c10::intrusive_ptr<Message> message,
67 const float rpcTimeoutSeconds,
68 const DeviceMap& /* unused */) {
69 // We only fail control messages that have been specified by the test case.
70 // For all other messages, we just send them without any failures.
71 if (!shouldFailMessage(message->type())) {
72 return TensorPipeAgent::send(to, std::move(message), rpcTimeoutSeconds);
73 }
74
75 // This send function checks the failMessageCountMap_ to check whether
76 // we must fail the next send. If the send must be failed, we set an error
77 // on the returned future immediately and increment the counter in the map,
78 // otherwise we just call the TensorPipeAgent send.
79 const auto key = fromVecToString(message->payload());
80 std::unique_lock<std::mutex> lock(failMapMutex_);
81 auto it = failMessageCountMap_.find(key);
82 if (it == failMessageCountMap_.end()) {
83 failMessageCountMap_[key] = 0;
84 }
85 if (failMessageCountMap_[key] < numFailSends_) {
86 failMessageCountMap_[key]++;
87 lock.unlock();
88 auto jitFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
89 jitFuture->setError(std::make_exception_ptr(std::runtime_error(makeRPCError(
90 c10::str("Send attempt failed intentionally for ", key),
91 RPCErrorType::INTENTIONAL_FAILURE))));
92 return jitFuture;
93 } else {
94 lock.unlock();
95 return TensorPipeAgent::send(to, std::move(message), rpcTimeoutSeconds);
96 }
97 }
98
pipeWrite(const std::shared_ptr<tensorpipe::Pipe> & pipe,c10::intrusive_ptr<Message> rpcMessage,std::vector<c10::Device> && devices,std::vector<c10::Stream> streams,std::function<void (const tensorpipe::Error &)> fn)99 void FaultyTensorPipeAgent::pipeWrite(
100 const std::shared_ptr<tensorpipe::Pipe>& pipe,
101 c10::intrusive_ptr<Message> rpcMessage,
102 std::vector<c10::Device>&& devices,
103 std::vector<c10::Stream> streams,
104 std::function<void(const tensorpipe::Error&)> fn) noexcept {
105 float msgDelay = getDelayForMessage(rpcMessage->type());
106 if (msgDelay != 0) {
107 // Sleep for the specified delay for the message.
108 std::this_thread::sleep_for(std::chrono::milliseconds(
109 static_cast<int>(msgDelay * kSecToMsConversion)));
110 }
111 TensorPipeAgent::pipeWrite(pipe, rpcMessage, std::move(devices), streams, fn);
112 }
113
shouldFailMessage(MessageType type) const114 bool FaultyTensorPipeAgent::shouldFailMessage(MessageType type) const {
115 // Return true if the input message type is in the messageTypesToFail_ list
116 return (
117 std::find(messageTypesToFail_.begin(), messageTypesToFail_.end(), type) !=
118 messageTypesToFail_.end());
119 }
120
getDelayForMessage(MessageType type) const121 float FaultyTensorPipeAgent::getDelayForMessage(MessageType type) const {
122 const auto& it = messageTypesToDelay_.find(type);
123 return it == messageTypesToDelay_.end() ? 0 : it->second;
124 }
125
messageStringToType(const std::string & messageString) const126 MessageType FaultyTensorPipeAgent::messageStringToType(
127 const std::string& messageString) const {
128 // Lazily constructed map that returns string to message type mapping
129 static std::unordered_map<std::string, MessageType> msgMap = {
130 {"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
131 {"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
132 {"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
133 {"CLEANUP_AUTOGRAD_CONTEXT_REQ",
134 MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
135 {"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
136 {"SCRIPT_REMOTE_CALL", MessageType::SCRIPT_REMOTE_CALL},
137 {"PYTHON_CALL", MessageType::PYTHON_CALL},
138 {"SCRIPT_CALL", MessageType::SCRIPT_CALL},
139 {"PYTHON_RREF_FETCH_CALL", MessageType::PYTHON_RREF_FETCH_CALL},
140 {"SCRIPT_RREF_FETCH_CALL", MessageType::SCRIPT_RREF_FETCH_CALL}};
141 const auto& it = msgMap.find(messageString);
142 TORCH_CHECK(
143 it != msgMap.end(),
144 "No mapping to rpc::MessageType exists for ",
145 messageString);
146 return it->second;
147 }
148
149 } // namespace rpc
150 } // namespace distributed
151 } // namespace torch
152
153 #endif // USE_TENSORPIPE
154