xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/script_remote_call.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/script_call.h>
4 #include <torch/csrc/distributed/rpc/types.h>
5 #include <torch/csrc/jit/runtime/operator.h>
6 #include <torch/csrc/jit/serialization/pickler.h>
7 #include <vector>
8 
9 namespace torch {
10 namespace distributed {
11 namespace rpc {
12 
13 using torch::jit::Operator;
14 
15 // A ScriptRemoteCall instance represents an invocation of `dist.remote` on a
16 // builtin operator. Currently, it does not support using RRef as arguments yet.
17 // Besides the operator and a vector of arguments, ScriptRemoteCall also
18 // contains the RRefId and the ForkId of the return value RRef.
19 class TORCH_API ScriptRemoteCall final : public ScriptCall {
20  public:
21   // Constructor for builitin operator call.
22   ScriptRemoteCall(
23       std::shared_ptr<Operator> op,
24       std::vector<at::IValue>&& stack,
25       const RRefId& retRRefId,
26       const ForkId& retForkId);
27 
28   // Constructor for TorchScript function call.
29   ScriptRemoteCall(
30       const c10::QualifiedName& qualifiedName,
31       std::vector<at::IValue>&& stack,
32       const RRefId& retRRefId,
33       const ForkId& retForkId,
34       const bool isAsyncExecution);
35 
retRRefId()36   inline const RRefId& retRRefId() const {
37     return retRRefId_;
38   }
39 
retForkId()40   inline const ForkId& retForkId() const {
41     return retForkId_;
42   }
43 
44   static std::unique_ptr<ScriptRemoteCall> fromIValues(
45       std::vector<at::IValue>& ivalues);
46 
47   c10::intrusive_ptr<Message> toMessageImpl() && override;
48   static std::unique_ptr<ScriptRemoteCall> fromMessage(const Message& message);
49 
50  private:
51   const RRefId retRRefId_;
52   const ForkId retForkId_;
53 };
54 
55 } // namespace rpc
56 } // namespace distributed
57 } // namespace torch
58