xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/autograd/profiler.h>
4 #include <torch/csrc/distributed/rpc/message.h>
5 #include <torch/csrc/distributed/rpc/rpc_agent.h>
6 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
7 #include <torch/csrc/distributed/rpc/types.h>
8 
9 namespace torch {
10 namespace distributed {
11 namespace autograd {
12 class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
13  public:
14   // For sending RPCs over the wire
15   RpcWithProfilingResp(
16       rpc::MessageType messageType,
17       c10::intrusive_ptr<rpc::Message> wrappedMessage,
18       std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
19       rpc::ProfilingId profilingId);
20 
21   // For receiving RPCs. Used in from message when converting a message received
22   // over the wire.
23   RpcWithProfilingResp(
24       rpc::MessageType messageType,
25       std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
26       rpc::MessageType wrappedMessageType,
27       std::vector<torch::Tensor> tensors,
28       std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
29       rpc::ProfilingId profilingId);
30   c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;
31   static std::unique_ptr<RpcWithProfilingResp> fromMessage(
32       const rpc::Message& message);
33   // Retrieve remote Events
34   std::vector<torch::autograd::profiler::LegacyEvent> getProfiledEvents() const;
35   // Retrieve the globally unique profiling ID corresponding to this command.
36   const rpc::ProfilingId& getProfilingId() const;
37   // Retrieve the original RPC which this ProfilingRPC wraps.
38   RpcCommandBase& wrappedRpc();
39   // Destructively move the wrapped RPC.
40   std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
41   // Message type of the wrapped RPC
42   rpc::MessageType wrappedMessageType() const;
43   // Set the wrapped RPC for this RPC.
44   void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);
45 
46  private:
47   // message type
48   const rpc::MessageType messageType_;
49   // wrapped message
50   c10::intrusive_ptr<rpc::Message> wrappedMessage_;
51   std::unique_ptr<RpcCommandBase> wrappedRpc_;
52   rpc::MessageType wrappedMessageType_;
53   std::vector<torch::Tensor> tensors_;
54   const std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents_;
55   const rpc::ProfilingId profilingId_;
56 };
57 } // namespace autograd
58 } // namespace distributed
59 } // namespace torch
60