xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/script_call.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
5 #include <torch/csrc/jit/runtime/operator.h>
6 #include <torch/csrc/jit/serialization/pickler.h>
7 #include <optional>
8 #include <vector>
9 
10 namespace torch {
11 namespace distributed {
12 namespace rpc {
13 
14 using torch::jit::Operator;
15 
16 // A ScriptCall instance represents an invocation of a builtin operator for a
17 // TorchScript function. If it is a builtin operator, it
18 // contains a shared ptr to the `Operator` and a list of arguments.
19 // If it is a TorchScript function, it contains a non empty qualifiedName string
20 // to the TorchScript function schema name and a list of arguments.
21 class TORCH_API ScriptCall : public RpcCommandBase {
22  public:
23   // Constructor for builitin operator call.
24   ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& stack);
25   // Constructor for TorchScript function call.
26   ScriptCall(
27       const c10::QualifiedName& qualifiedName,
28       std::vector<at::IValue>&& stack,
29       const bool isAsyncExecution = false);
30 
31   bool hasOp() const;
32   std::shared_ptr<Operator> op() const;
33   bool hasQualifiedName() const;
34   const c10::QualifiedName& qualifiedName() const;
35   // return the argument stack of this builtin operator
36   const std::vector<at::IValue>& stack() const;
37   std::vector<at::IValue>& stackRef();
isAsyncExecution()38   inline bool isAsyncExecution() const {
39     return isAsyncExecution_;
40   }
41 
42   c10::intrusive_ptr<Message> toMessageImpl() && override;
43   static std::unique_ptr<ScriptCall> fromMessage(const Message& message);
44 
45   ~ScriptCall() override = default;
46 
47  protected:
48   virtual void toIValues(std::vector<at::IValue>& ivalues) const;
49   static std::unique_ptr<ScriptCall> fromIValues(
50       std::vector<at::IValue>& ivalues);
51 
52  private:
53   // Given an operator symbol and a string schema, return the matched operator.
54   static std::shared_ptr<Operator> matchOperator(const std::string& str_schema);
55 
56   static const std::string BUILTIN_OP_NAMESPACE_;
57   static const std::string ATEN_PREFIX_;
58 
59   // This field has value if this ScriptCall represents invocation of a builtin
60   // operator.
61   std::optional<std::shared_ptr<Operator>> op_;
62   // This field has non empty string if this ScriptCall represents invocation of
63   // an annotated torchscript function defined by users.
64   std::optional<const c10::QualifiedName> qualifiedName_;
65   std::vector<at::IValue> stack_;
66   const bool isAsyncExecution_;
67 };
68 
69 } // namespace rpc
70 } // namespace distributed
71 } // namespace torch
72