xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/script_call.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/rpc_agent.h>
2 #include <torch/csrc/distributed/rpc/script_call.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 
5 namespace torch::distributed::rpc {
6 
7 const std::string ScriptCall::BUILTIN_OP_NAMESPACE_("torch.ops.aten.");
8 const std::string ScriptCall::ATEN_PREFIX_("aten::");
9 
ScriptCall(std::shared_ptr<Operator> op,std::vector<at::IValue> && stack)10 ScriptCall::ScriptCall(
11     std::shared_ptr<Operator> op,
12     std::vector<at::IValue>&& stack)
13     : op_(std::move(op)), stack_(stack), isAsyncExecution_(false) {}
14 
ScriptCall(const c10::QualifiedName & qualifiedName,std::vector<at::IValue> && stack,const bool isAsyncExecution)15 ScriptCall::ScriptCall(
16     const c10::QualifiedName& qualifiedName,
17     std::vector<at::IValue>&& stack,
18     const bool isAsyncExecution)
19     : qualifiedName_(qualifiedName),
20       stack_(stack),
21       isAsyncExecution_(isAsyncExecution) {}
22 
hasOp() const23 bool ScriptCall::hasOp() const {
24   return op_ ? true : false;
25 }
26 
op() const27 std::shared_ptr<Operator> ScriptCall::op() const {
28   return *op_;
29 }
30 
hasQualifiedName() const31 bool ScriptCall::hasQualifiedName() const {
32   return qualifiedName_ ? true : false;
33 }
34 
qualifiedName() const35 const c10::QualifiedName& ScriptCall::qualifiedName() const {
36   return *qualifiedName_;
37 }
38 
stack() const39 const std::vector<at::IValue>& ScriptCall::stack() const {
40   return stack_;
41 }
42 
stackRef()43 std::vector<at::IValue>& ScriptCall::stackRef() {
44   return stack_;
45 }
46 
toIValues(std::vector<at::IValue> & ivalues) const47 void ScriptCall::toIValues(std::vector<at::IValue>& ivalues) const {
48   for (auto& value : stack_) {
49     ivalues.push_back(value);
50   }
51 
52   if (hasOp()) {
53     TORCH_CHECK(
54         !hasQualifiedName(),
55         "It is builtin operator call, qualifiedName_ should not be set.");
56     // TODO: replace this with a real overload_name when FunctionSchema supports
57     // that.
58     ivalues.emplace_back(toString((*op_)->schema()));
59     // insert qualified name
60     auto opName = (*op_)->schema().name();
61     TORCH_CHECK(
62         opName.find("::") == opName.rfind("::") &&
63             opName.rfind(ATEN_PREFIX_) == 0,
64         "Unexpected operator name ",
65         opName);
66     // aten::add -> torch.ops.aten.add
67     opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_);
68     ivalues.emplace_back(std::move(opName));
69   } else if (hasQualifiedName()) {
70     ivalues.emplace_back(isAsyncExecution());
71     TORCH_CHECK(
72         !hasOp(),
73         "It is TorchScript function call, operator should not be set.");
74     ivalues.emplace_back((*qualifiedName_).qualifiedName());
75   } else {
76     TORCH_INTERNAL_ASSERT(
77         false,
78         "Either builtin operator or TorchScript function name should be set.");
79   }
80 }
81 
fromIValues(std::vector<at::IValue> & ivalues)82 std::unique_ptr<ScriptCall> ScriptCall::fromIValues(
83     std::vector<at::IValue>& ivalues) {
84   TORCH_INTERNAL_ASSERT(
85       ivalues.size() > 1,
86       "At least 2 IValues are required to build a ScriptCall.");
87 
88   // Last element in the vector is always qualifiedName for both
89   // builitin operator and TorchScript function
90   // If the qualifiedName is not a builtin operator name, then treat it
91   // as TorchScript function name
92   const std::string& qualifiedName = ivalues.back().toStringRef();
93 
94   if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
95     ivalues.pop_back();
96     const std::string& str_schema = ivalues.back().toStringRef();
97     auto op = matchOperator(str_schema);
98 
99     ivalues.pop_back();
100     // remove str_schema from ivalues
101     return std::make_unique<ScriptCall>(op, std::move(ivalues));
102   } else {
103     ivalues.pop_back();
104     bool isAsyncExecution = ivalues.back().toBool();
105     ivalues.pop_back();
106     return std::make_unique<ScriptCall>(
107         c10::QualifiedName(qualifiedName),
108         std::move(ivalues),
109         isAsyncExecution);
110   }
111 }
112 
toMessageImpl()113 c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && {
114   std::vector<IValue> ivalues;
115   toIValues(ivalues);
116 
117   std::vector<torch::Tensor> tensor_table;
118   auto payload = jit::pickle(
119       c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
120 
121   return c10::make_intrusive<Message>(
122       std::move(payload), std::move(tensor_table), MessageType::SCRIPT_CALL);
123 }
124 
fromMessage(const Message & message)125 std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
126   auto payload = static_cast<const char*>(message.payload().data());
127   auto payload_size = message.payload().size();
128   auto value = jit::unpickle(
129       payload,
130       payload_size,
131       *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
132       message.tensors());
133 
134   auto values = value.toTupleRef().elements().vec();
135   return fromIValues(values);
136 }
137 
matchOperator(const std::string & str_schema)138 std::shared_ptr<Operator> ScriptCall::matchOperator(
139     const std::string& str_schema) {
140   // TODO: This is a temporary solution. We should pass enough information to
141   // allow deterministically matched to one operator.
142 
143   // extract symbol from the schema
144   auto schema = torch::jit::parseSchema(str_schema);
145   auto symbol = at::Symbol::fromQualString(schema.name());
146 
147   for (auto op : torch::jit::getAllOperatorsFor(symbol)) {
148     if (toString(op->schema()) == str_schema) {
149       return op;
150     }
151   }
152 
153   TORCH_CHECK(false, "Cannot find matching operator for schema ", str_schema);
154 }
155 
156 } // namespace torch::distributed::rpc
157