1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 5 #include <torch/csrc/jit/runtime/operator.h> 6 #include <torch/csrc/jit/serialization/pickler.h> 7 #include <optional> 8 #include <vector> 9 10 namespace torch { 11 namespace distributed { 12 namespace rpc { 13 14 using torch::jit::Operator; 15 16 // A ScriptCall instance represents an invocation of a builtin operator for a 17 // TorchScript function. If it is a builtin operator, it 18 // contains a shared ptr to the `Operator` and a list of arguments. 19 // If it is a TorchScript function, it contains a non empty qualifiedName string 20 // to the TorchScript function schema name and a list of arguments. 21 class TORCH_API ScriptCall : public RpcCommandBase { 22 public: 23 // Constructor for builitin operator call. 24 ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& stack); 25 // Constructor for TorchScript function call. 26 ScriptCall( 27 const c10::QualifiedName& qualifiedName, 28 std::vector<at::IValue>&& stack, 29 const bool isAsyncExecution = false); 30 31 bool hasOp() const; 32 std::shared_ptr<Operator> op() const; 33 bool hasQualifiedName() const; 34 const c10::QualifiedName& qualifiedName() const; 35 // return the argument stack of this builtin operator 36 const std::vector<at::IValue>& stack() const; 37 std::vector<at::IValue>& stackRef(); isAsyncExecution()38 inline bool isAsyncExecution() const { 39 return isAsyncExecution_; 40 } 41 42 c10::intrusive_ptr<Message> toMessageImpl() && override; 43 static std::unique_ptr<ScriptCall> fromMessage(const Message& message); 44 45 ~ScriptCall() override = default; 46 47 protected: 48 virtual void toIValues(std::vector<at::IValue>& ivalues) const; 49 static std::unique_ptr<ScriptCall> fromIValues( 50 std::vector<at::IValue>& ivalues); 51 52 private: 53 // Given an operator symbol and a string schema, return the matched operator. 54 static std::shared_ptr<Operator> matchOperator(const std::string& str_schema); 55 56 static const std::string BUILTIN_OP_NAMESPACE_; 57 static const std::string ATEN_PREFIX_; 58 59 // This field has value if this ScriptCall represents invocation of a builtin 60 // operator. 61 std::optional<std::shared_ptr<Operator>> op_; 62 // This field has non empty string if this ScriptCall represents invocation of 63 // an annotated torchscript function defined by users. 64 std::optional<const c10::QualifiedName> qualifiedName_; 65 std::vector<at::IValue> stack_; 66 const bool isAsyncExecution_; 67 }; 68 69 } // namespace rpc 70 } // namespace distributed 71 } // namespace torch 72