#pragma once

#ifdef USE_TENSORPIPE

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>

// Standard headers for the types named directly in this file, so the header
// does not rely on transitive includes.
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

// Backend options for FaultyTensorPipeAgent. Extends
// TensorPipeRpcBackendOptions with the fault-injection settings stored in the
// three public members below.
struct TORCH_API FaultyTensorPipeRpcBackendOptions
    : public TensorPipeRpcBackendOptions {
  // Builds the options object.
  //
  // num_worker_threads, rpc_timeout and init_method are forwarded to the
  // TensorPipeRpcBackendOptions constructor; the two base-class parameters
  // that sit between num_worker_threads and rpc_timeout are passed as empty
  // optionals.
  // NOTE(review): those two optionals presumably select transports/channels
  // -- confirm against TensorPipeRpcBackendOptions.
  //
  // messages_to_fail:  names of the message types that should fail.
  // messages_to_delay: message type name -> delay value (the agent's
  //                    getDelayForMessage() documents delays in seconds).
  // num_fail_sends:    must be non-negative; enforced with TORCH_CHECK.
  FaultyTensorPipeRpcBackendOptions(
      int num_worker_threads,
      float rpc_timeout,
      std::string init_method,
      std::vector<std::string> messages_to_fail,
      std::unordered_map<std::string, float> messages_to_delay,
      int num_fail_sends = 0)
      : TensorPipeRpcBackendOptions(
            num_worker_threads,
            std::optional<std::vector<std::string>>(),
            std::optional<std::vector<std::string>>(),
            rpc_timeout,
            std::move(init_method)),
        messagesToFail(std::move(messages_to_fail)),
        messagesToDelay(std::move(messages_to_delay)),
        numFailSends(num_fail_sends) {
    TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
  }

  // Message type names that should fail.
  std::vector<std::string> messagesToFail;
  // Message type name -> delay to inject.
  std::unordered_map<std::string, float> messagesToDelay;
  // Number of sends to fail; validated as non-negative in the constructor.
  int numFailSends;
};

// A TensorPipeAgent that overrides send() and pipeWrite() so that selected
// message types can be failed or delayed on purpose. The existing comments
// below indicate the configuration comes from the python tests.
class TORCH_API FaultyTensorPipeAgent : public TensorPipeAgent {
 public:
  FaultyTensorPipeAgent(
      const c10::intrusive_ptr<::c10d::Store>& store,
      std::string selfName,
      worker_id_t selfId,
      int worldSize,
      FaultyTensorPipeRpcBackendOptions opts,
      std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
      std::vector<c10::Device> devices,
      std::unique_ptr<RequestCallback> callback);

  // Faulty send function for this class.
  c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
      const DeviceMap& deviceMap = {}) override;

  // Add delay to writes
  void pipeWrite(
      const std::shared_ptr<tensorpipe::Pipe>& pipe,
      c10::intrusive_ptr<Message> rpcMessage,
      std::vector<c10::Device>&& devices,
      std::vector<c10::Stream> streams,
      std::function<void(const tensorpipe::Error&)> fn) noexcept override;

 protected:
  // This function checks the messageTypesToFail_ to determine whether to use
  // the faulty send or not.
  bool shouldFailMessage(MessageType type) const;

 private:
  // This function parses the list of strings passed in by the python tests and
  // resolves the Message Types that must use the faulty send.
  std::vector<MessageType> parseMessagesToFailInput(
      const std::vector<std::string>& messagesToFail) const;

  // Returns amount of time in seconds to delay sending of the given message
  // type.
  float getDelayForMessage(MessageType type) const;

  // Parse message types that we should inject arbitrary delays for.
  std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
      const std::unordered_map<std::string, float>& messageTypesToDelay) const;

  // Number of sends to intentionally fail before allowing one to succeed.
  const int numFailSends_;

  // Vector of the MessageTypes that we must use the faulty send for. This is
  // parsed based on a list of strings passed in by the python tests.
  const std::vector<MessageType> messageTypesToFail_;

  // Mapping of message types to amount we should delay send for in the ::send()
  // function.
  std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;

  // Map to track the number of sends we've failed for each RPC.
  std::unordered_map<std::string, int> failMessageCountMap_;

  // Mutex to guard failMessageCountMap_
  std::mutex failMapMutex_;

  // Converts a message type name into the corresponding MessageType.
  MessageType messageStringToType(const std::string& messageString) const;
};

} // namespace rpc
} // namespace distributed
} // namespace torch

#endif // USE_TENSORPIPE