// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/autograd/profiler.h>
4 #include <torch/csrc/distributed/rpc/message.h>
5 #include <torch/csrc/distributed/rpc/rpc_agent.h>
6 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
7 #include <torch/csrc/distributed/rpc/types.h>
8 
9 namespace torch {
10 namespace distributed {
11 namespace autograd {
12 
13 class TORCH_API RpcWithProfilingReq : public rpc::RpcCommandBase {
14  public:
15   // For sending RPCs, invoked when client is creating this RPC command.
16   RpcWithProfilingReq(
17       rpc::MessageType messageType,
18       c10::intrusive_ptr<rpc::Message> wrappedMessage,
19       torch::autograd::profiler::ProfilerConfig&& profilerConfig,
20       rpc::ProfilingId profilingKeyId);
21 
22   // For receiving an RPC
23   // Used in fromMessage.
24   RpcWithProfilingReq(
25       rpc::MessageType messageType,
26       std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
27       rpc::MessageType wrappedMessageType,
28       std::vector<torch::Tensor> tensors,
29       torch::autograd::profiler::ProfilerConfig&& profilerConfig,
30       rpc::ProfilingId profilingKeyId);
31 
32   // Convert this RPC Command to a Message that can be sent over the wire.
33   c10::intrusive_ptr<rpc::Message> toMessageImpl() && override;
34   static std::unique_ptr<RpcWithProfilingReq> fromMessage(
35       const rpc::Message& message);
36 
37   // Retrieve the profiling data that is associated with this command.
38   torch::autograd::profiler::ProfilerConfig getProfilingConfig() const;
39   // Retrieve the globally unique profiling ID corresponding to this command.
40   const rpc::ProfilingId& getProfilingId() const;
41   // Retrieve the original RPC which this ProfilingRPC wraps.
42   RpcCommandBase& wrappedRpc();
43   // Destructively move the wrapped RPC.
44   std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
45   // Message type of the wrapped RPC
46   rpc::MessageType wrappedMessageType() const;
47   void setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc);
48 
49  private:
50   // message type
51   const rpc::MessageType messageType_;
52   // wrapped message
53   c10::intrusive_ptr<rpc::Message> wrappedMessage_;
54   std::unique_ptr<RpcCommandBase> wrappedRpc_;
55   rpc::MessageType wrappedMessageType_;
56   std::vector<torch::Tensor> tensors_;
57   const torch::autograd::profiler::ProfilerConfig profilerConfig_;
58   const rpc::ProfilingId profilingKeyId_;
59 };
60 } // namespace autograd
61 } // namespace distributed
62 } // namespace torch
63