1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/script_call.h> 4 #include <torch/csrc/distributed/rpc/types.h> 5 #include <torch/csrc/jit/runtime/operator.h> 6 #include <torch/csrc/jit/serialization/pickler.h> 7 #include <vector> 8 9 namespace torch { 10 namespace distributed { 11 namespace rpc { 12 13 using torch::jit::Operator; 14 15 // A ScriptRemoteCall instance represents an invocation of `dist.remote` on a 16 // builtin operator. Currently, it does not support using RRef as arguments yet. 17 // Besides the operator and a vector of arguments, ScriptRemoteCall also 18 // contains the RRefId and the ForkId of the return value RRef. 19 class TORCH_API ScriptRemoteCall final : public ScriptCall { 20 public: 21 // Constructor for builitin operator call. 22 ScriptRemoteCall( 23 std::shared_ptr<Operator> op, 24 std::vector<at::IValue>&& stack, 25 const RRefId& retRRefId, 26 const ForkId& retForkId); 27 28 // Constructor for TorchScript function call. 29 ScriptRemoteCall( 30 const c10::QualifiedName& qualifiedName, 31 std::vector<at::IValue>&& stack, 32 const RRefId& retRRefId, 33 const ForkId& retForkId, 34 const bool isAsyncExecution); 35 retRRefId()36 inline const RRefId& retRRefId() const { 37 return retRRefId_; 38 } 39 retForkId()40 inline const ForkId& retForkId() const { 41 return retForkId_; 42 } 43 44 static std::unique_ptr<ScriptRemoteCall> fromIValues( 45 std::vector<at::IValue>& ivalues); 46 47 c10::intrusive_ptr<Message> toMessageImpl() && override; 48 static std::unique_ptr<ScriptRemoteCall> fromMessage(const Message& message); 49 50 private: 51 const RRefId retRRefId_; 52 const ForkId retForkId_; 53 }; 54 55 } // namespace rpc 56 } // namespace distributed 57 } // namespace torch 58