#pragma once

#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>

namespace torch {
namespace distributed {
namespace autograd {

// RPC command that bundles another RPC (held either as a serialized
// rpc::Message or as a deserialized RpcCommandBase) together with a profiler
// configuration and a profiling ID.
// NOTE(review): presumably this lets the receiving side run the wrapped RPC
// with the sender's profiler settings -- confirm against the .cpp and the
// request callback that consumes this command.
//
// The class holds a std::unique_ptr and several const data members, so
// instances are neither copyable nor assignable.
class TORCH_API RpcWithProfilingReq : public rpc::RpcCommandBase {
 public:
  // For sending RPCs, invoked when client is creating this RPC command.
  //   messageType    - message type of this (wrapping) command.
  //   wrappedMessage - the message being wrapped.
  //   profilerConfig - profiler configuration; taken by rvalue reference.
  //   profilingKeyId - profiling ID for this command.
  // This overload takes no RpcCommandBase, so wrappedRpc_ is not supplied
  // through this constructor.
  RpcWithProfilingReq(
      rpc::MessageType messageType,
      c10::intrusive_ptr<rpc::Message> wrappedMessage,
      torch::autograd::profiler::ProfilerConfig&& profilerConfig,
      rpc::ProfilingId profilingKeyId);

  // For receiving an RPC
  // Used in fromMessage.
  //   messageType        - message type of this (wrapping) command.
  //   wrappedRpc         - the wrapped RPC command; ownership is transferred
  //                        in via the unique_ptr.
  //   wrappedMessageType - message type of the wrapped RPC.
  //   tensors            - tensors accompanying the wrapped RPC (presumably
  //                        the tensors of the received message -- TODO
  //                        confirm).
  //   profilerConfig     - profiler configuration; taken by rvalue reference.
  //   profilingKeyId     - profiling ID for this command.
  RpcWithProfilingReq(
      rpc::MessageType messageType,
      std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
      rpc::MessageType wrappedMessageType,
      std::vector<torch::Tensor> tensors,
      torch::autograd::profiler::ProfilerConfig&& profilerConfig,
      rpc::ProfilingId profilingKeyId);

  // Convert this RPC Command to a Message that can be sent over the wire.
  // Rvalue-ref-qualified: it can only be invoked on an rvalue (for example
  // std::move(cmd)), so the command should be treated as consumed afterwards.
  c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;

  // Create an RpcWithProfilingReq from `message`. The caller owns the
  // returned object.
  static std::unique_ptr<RpcWithProfilingReq> fromMessage(
      const rpc::Message& message);

  // Retrieve the profiler configuration that is associated with this command.
  // Returned by value, i.e. the caller gets its own copy.
  torch::autograd::profiler::ProfilerConfig getProfilingConfig() const;
  // Retrieve the globally unique profiling ID corresponding to this command.
  // Returned by const reference.
  const rpc::ProfilingId& getProfilingId() const;
  // Retrieve the original RPC which this ProfilingRPC wraps.
  // NOTE(review): the sending constructor receives no RpcCommandBase, so
  // confirm a wrapped RPC has been set before calling this on such an object.
  RpcCommandBase& wrappedRpc();
  // Destructively move the wrapped RPC.
  // Rvalue-ref-qualified; ownership of the wrapped RPC passes to the caller.
  std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
  // Message type of the wrapped RPC
  rpc::MessageType wrappedMessageType() const;
  // Hand a wrapped RPC to this command; ownership of `wrappedRpc` is
  // transferred in via the unique_ptr.
  void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);

 private:
  // message type
  // Type of this wrapping command; const, so fixed at construction.
  const rpc::MessageType messageType_;
  // wrapped message
  // Supplied through the sending constructor.
  c10::intrusive_ptr<rpc::Message> wrappedMessage_;
  // Wrapped RPC in command form; supplied through the receiving constructor
  // or setWrappedRpc().
  std::unique_ptr<RpcCommandBase> wrappedRpc_;
  // Message type of the wrapped RPC.
  rpc::MessageType wrappedMessageType_;
  // Tensors supplied through the receiving constructor.
  std::vector<torch::Tensor> tensors_;
  // Profiler configuration; const, so fixed at construction.
  const torch::autograd::profiler::ProfilerConfig profilerConfig_;
  // Profiling ID; const, so fixed at construction.
  const rpc::ProfilingId profilingKeyId_;
};
} // namespace autograd
} // namespace distributed
} // namespace torch
