xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_TENSORPIPE
4 
5 #include <torch/csrc/distributed/rpc/message.h>
6 #include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
7 
8 namespace torch {
9 namespace distributed {
10 namespace rpc {
11 
12 struct TORCH_API FaultyTensorPipeRpcBackendOptions
13     : public TensorPipeRpcBackendOptions {
14   FaultyTensorPipeRpcBackendOptions(
15       int num_worker_threads,
16       float rpc_timeout,
17       std::string init_method,
18       std::vector<std::string> messages_to_fail,
19       std::unordered_map<std::string, float> messages_to_delay,
20       int num_fail_sends = 0)
TensorPipeRpcBackendOptionsFaultyTensorPipeRpcBackendOptions21       : TensorPipeRpcBackendOptions(
22             num_worker_threads,
23             std::optional<std::vector<std::string>>(),
24             std::optional<std::vector<std::string>>(),
25             rpc_timeout,
26             std::move(init_method)),
27         messagesToFail(std::move(messages_to_fail)),
28         messagesToDelay(std::move(messages_to_delay)),
29         numFailSends(num_fail_sends) {
30     TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
31   }
32 
33   std::vector<std::string> messagesToFail;
34   std::unordered_map<std::string, float> messagesToDelay;
35   int numFailSends;
36 };
37 
38 class TORCH_API FaultyTensorPipeAgent : public TensorPipeAgent {
39  public:
40   FaultyTensorPipeAgent(
41       const c10::intrusive_ptr<::c10d::Store>& store,
42       std::string selfName,
43       worker_id_t selfId,
44       int worldSize,
45       FaultyTensorPipeRpcBackendOptions opts,
46       std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
47       std::vector<c10::Device> devices,
48       std::unique_ptr<RequestCallback> callback);
49 
50   // Faulty send function for this class.
51   c10::intrusive_ptr<JitFuture> send(
52       const WorkerInfo& to,
53       c10::intrusive_ptr<Message> message,
54       const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
55       const DeviceMap& deviceMap = {}) override;
56 
57   // Add delay to writes
58   void pipeWrite(
59       const std::shared_ptr<tensorpipe::Pipe>& pipe,
60       c10::intrusive_ptr<Message> rpcMessage,
61       std::vector<c10::Device>&& devices,
62       std::vector<c10::Stream> streams,
63       std::function<void(const tensorpipe::Error&)> fn) noexcept override;
64 
65  protected:
66   // This function checks the messageTypesToFail_ to determine whether to use
67   // the faulty send or not.
68   bool shouldFailMessage(MessageType type) const;
69 
70  private:
71   // This function parses the list of strings passed in by the python tests and
72   // resolves the Message Types that must use the faulty send.
73   std::vector<MessageType> parseMessagesToFailInput(
74       const std::vector<std::string>& messagesToFail) const;
75 
76   // Returns amount of time in seconds to delay sending of the given message
77   // type.
78   float getDelayForMessage(MessageType type) const;
79 
80   // Parse message types that we should inject arbitrary delays for.
81   std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
82       const std::unordered_map<std::string, float>& messageTypesToDelay) const;
83 
84   // Number of sends to intentionally fail before allowing one to succeed.
85   const int numFailSends_;
86 
87   // Vector of the MessageTypes that we must use the faulty send for. This is
88   // parsed based on a list of strings passed in by the python tests.
89   const std::vector<MessageType> messageTypesToFail_;
90 
91   // Mapping of message types to amount we should delay send for in the ::send()
92   // function.
93   std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;
94 
95   // Map to track the number of sends we've failed for each RPC.
96   std::unordered_map<std::string, int> failMessageCountMap_;
97 
98   // Mutex to guard failMessageCountMap_
99   std::mutex failMapMutex_;
100 
101   MessageType messageStringToType(const std::string& messageString) const;
102 };
103 
104 } // namespace rpc
105 } // namespace distributed
106 } // namespace torch
107 
108 #endif // USE_TENSORPIPE
109