1 #pragma once 2 3 #include <torch/csrc/autograd/profiler.h> 4 #include <torch/csrc/distributed/rpc/message.h> 5 #include <torch/csrc/distributed/rpc/rpc_agent.h> 6 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 7 #include <torch/csrc/distributed/rpc/types.h> 8 9 namespace torch { 10 namespace distributed { 11 namespace autograd { 12 class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase { 13 public: 14 // For sending RPCs over the wire 15 RpcWithProfilingResp( 16 rpc::MessageType messageType, 17 c10::intrusive_ptr<rpc::Message> wrappedMessage, 18 std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents, 19 rpc::ProfilingId profilingId); 20 21 // For receiving RPCs. Used in from message when converting a message received 22 // over the wire. 23 RpcWithProfilingResp( 24 rpc::MessageType messageType, 25 std::unique_ptr<rpc::RpcCommandBase> wrappedRpc, 26 rpc::MessageType wrappedMessageType, 27 std::vector<torch::Tensor> tensors, 28 std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents, 29 rpc::ProfilingId profilingId); 30 c10::intrusive_ptr<rpc::Message> toMessageImpl() && override; 31 static std::unique_ptr<RpcWithProfilingResp> fromMessage( 32 const rpc::Message& message); 33 // Retrieve remote Events 34 std::vector<torch::autograd::profiler::LegacyEvent> getProfiledEvents() const; 35 // Retrieve the globally unique profiling ID corresponding to this command. 36 const rpc::ProfilingId& getProfilingId() const; 37 // Retrieve the original RPC which this ProfilingRPC wraps. 38 RpcCommandBase& wrappedRpc(); 39 // Destructively move the wrapped RPC. 40 std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&; 41 // Message type of the wrapped RPC 42 rpc::MessageType wrappedMessageType() const; 43 // Set the wrapped RPC for this RPC. 44 void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc); 45 46 private: 47 // message type 48 const rpc::MessageType messageType_; 49 // wrapped message 50 c10::intrusive_ptr<rpc::Message> wrappedMessage_; 51 std::unique_ptr<RpcCommandBase> wrappedRpc_; 52 rpc::MessageType wrappedMessageType_; 53 std::vector<torch::Tensor> tensors_; 54 const std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents_; 55 const rpc::ProfilingId profilingId_; 56 }; 57 } // namespace autograd 58 } // namespace distributed 59 } // namespace torch 60