xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/register_distributed_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/library.h>

#include <fmt/format.h>
#include <stdexcept>

namespace dist_autograd = torch::distributed::autograd;
namespace dist_rpc = torch::distributed::rpc;

namespace torch::jit {

namespace {
distributed::rpc::RegisterWorkerInfoOnce workerInfo{};

// prepare the rpc input arguments and call the C++ impls
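// Expected stack layout, oldest input first: destination worker (str or
// WorkerInfo), qualified function name (str), then optional args Tuple,
// optional kwargs Dict, and optional timeout (float).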
void prepare_and_call_rpc_op(
    Stack& stack,
    int num_inputs,
    const std::string& rpc_op) {
  // Get inputs from the stack.
  auto stackIter = stack.end() - num_inputs;
  auto& dstWorkerIValue = *stackIter++;
  auto& qualifiedNameIValue = *stackIter++;
  IValue emptyTuple(c10::ivalue::Tuple::create({}));
  IValue emptyDict{c10::impl::GenericDict(AnyType::get(), AnyType::get())};
  // Equivalent to Python statement
  // `args = args if args is not None else ()`.
  auto& argsTupleIValue = num_inputs >= 3 ? *stackIter++ : emptyTuple;
  // `kwargs = kwargs if kwargs is not None else {}`.
  auto& kwargsDictIValue = num_inputs >= 4 ? *stackIter++ : emptyDict;

  // IValue corresponding to placeholder for RPC timeout. Used if no
  // rpc timeout is specified by user.
  IValue noTimeout(torch::distributed::rpc::kUnsetRpcTimeout);
  const auto rpcMaxInputs = 5;
  auto& timeoutIValue = num_inputs >= rpcMaxInputs ? *stackIter++ : noTimeout;
  TORCH_INTERNAL_ASSERT(
      dstWorkerIValue.isString() ||
      c10::getCustomClassType<c10::intrusive_ptr<dist_rpc::WorkerInfo>>() ==
          dstWorkerIValue.type());
  TORCH_INTERNAL_ASSERT(qualifiedNameIValue.isString());
  TORCH_INTERNAL_ASSERT(argsTupleIValue.isTuple());
  TORCH_INTERNAL_ASSERT(kwargsDictIValue.isGenericDict());
  TORCH_INTERNAL_ASSERT(timeoutIValue.isDouble());

  // Get FunctionSchema for qualifiedName.
  auto qualifiedName = c10::QualifiedName(qualifiedNameIValue.toStringRef());
  std::shared_ptr<CompilationUnit> cuPtr;
  {
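    // get_python_cu() touches the Python-side compilation unit, so hold the
    // GIL while looking it up.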
    py::gil_scoped_acquire acquire;
    cuPtr = get_python_cu();
  }
  auto& functionSchema = cuPtr->get_function(qualifiedName).getSchema();

  // Build the stack for the user callable.
  // It's similar to
  // Stack createStackForSchema(FunctionSchema, py::args,
  // py::kwargs). Instead, it's Stack
  // createStackForSchema(FunctionSchema, IValue<Tuple>,
  // IValue<Dict>).
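  // For example, for a callable `def f(x, y=1, z=2)` invoked with
  // args=(x,) and kwargs={"z": 3}, the resulting stack is [x, 1, 3]:
  // positional args are pushed first, then each remaining schema argument is
  // taken from kwargs or from its default value.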
  Stack userCallableStack;
  userCallableStack.reserve(functionSchema.arguments().size());

  // Move args from Tuple IValue to Stack.
  for (auto& elem : argsTupleIValue.toTupleRef().elements()) {
    push(userCallableStack, std::move(elem));
  }

  // Move kwargs from Dict IValue to Stack.
  size_t consumed_kwargs = 0;
  auto kwargsDict = kwargsDictIValue.toGenericDict();
  for (size_t i = userCallableStack.size();
       i < functionSchema.arguments().size();
       ++i) {
    const auto& arg = functionSchema.arguments()[i];
    const auto& argName = arg.name();
    if (kwargsDict.contains(argName)) {
      push(userCallableStack, kwargsDict.at(argName));
      consumed_kwargs += 1;
    } else if (arg.default_value()) {
      push(userCallableStack, *arg.default_value());
    } else {
      throw std::runtime_error(c10::str(
          functionSchema.name(),
          "() is missing value for argument '",
          argName,
          "'. Declaration: ",
          functionSchema));
    }
  }
  // Raise exception showing the unexpected kwargs.
  if (consumed_kwargs != kwargsDict.size()) {
    std::vector<std::string> names;
    for (const auto& entry : kwargsDict) {
      const IValue& keyIValue = entry.key();
      const std::string& keyStr = keyIValue.toStringRef();
      names.emplace_back(keyStr);
    }
    throw std::runtime_error(functionSchema.findErrorInKwargs(names));
  }

  // Get destination WorkerName.
  std::string dstWorkerNameStr;
  if (dstWorkerIValue.isString()) {
    // ivalue::ConstantString::str_ is a const member, which can't be
    // moved, so copy it here.
    dstWorkerNameStr = dstWorkerIValue.toStringRef();
  } else {
    dstWorkerNameStr =
        dstWorkerIValue.toCustomClass<dist_rpc::WorkerInfo>()->name_;
  }
  // Get RPC timeout, if specified by user.
  const auto rpcTimeout = timeoutIValue.toDouble();

  if (rpc_op == "rpc_async") {
    // Send RPC request.
    auto futureIValuePtr = dist_rpc::rpcTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    // Push output to the stack.
    drop(stack, num_inputs);
    stack.emplace_back(std::move(futureIValuePtr));
  } else if (rpc_op == "rpc_sync") {
    // Send RPC request.
    auto futureIValuePtr = dist_rpc::rpcTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    futureIValuePtr->wait();
    if (futureIValuePtr->hasError()) {
      // Throw if the future completed with an error.
      throw std::runtime_error(futureIValuePtr->tryRetrieveErrorMessage());
    } else {
      auto res = futureIValuePtr->value();
      // Push output to the stack.
      drop(stack, num_inputs);
      stack.emplace_back(std::move(res));
    }
  } else if (rpc_op == "rpc_remote") {
    auto rrefPtr = dist_rpc::remoteTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    // Push output to the stack.
    drop(stack, num_inputs);
    stack.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rrefPtr));
  } else {
    throw std::runtime_error(
        c10::str(rpc_op, "() is not supported in TorchScript!"));
  }
}
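// TorchScript builtins registered below: RRef accessors (to_here, local_value,
// is_owner, owner, owner_name, confirmed_by_owner), dist_backward for the
// distributed autograd engine, and the prim::rpc_sync / prim::rpc_remote /
// prim::rpc_async ops, which forward to prepare_and_call_rpc_op above.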
RegisterOperators reg_rpc_ops(
    {Operator(
         fmt::format(
             "aten::to_here(RRef(t) self, float timeout = {}) -> t(*)",
             torch::distributed::rpc::kDefaultRpcTimeoutSeconds),
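         // An owner RRef reads its value locally; a user RRef fetches it from
         // the owner, blocking for up to `timeout` seconds.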
         [](Stack& stack) {
           auto timeout = pop(stack).toDouble();
           auto rref = pop(stack).toRRef();
           IValue res;
           if (rref->isOwner()) {
             res =
                 c10::dynamic_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
                     ->getValue();
           } else {
             res = c10::dynamic_intrusive_pointer_cast<dist_rpc::UserRRef>(rref)
                       ->toHere(timeout);
           }
           push(stack, std::move(res));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::local_value(RRef(t) self) -> t(*)",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           TORCH_CHECK(
               rref->isOwner(),
               "Can't call RRef.local_value() on a non-owner RRef.");
           IValue res =
               c10::static_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
                   ->getValue();
           push(stack, std::move(res));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::is_owner(RRef(t) self) -> bool",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(stack, rref->isOwner());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::owner(RRef(t) self) -> __torch__.torch.classes.dist_rpc.WorkerInfo",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(
               stack,
               torch::make_custom_class<distributed::rpc::WorkerInfo>(
                   rref->ownerName(), rref->owner()));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::owner_name(RRef(t) self) -> str",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(stack, rref->ownerName());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::confirmed_by_owner(RRef(t) self) -> bool",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(stack, rref->confirmedByOwner());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::dist_backward(int context_id, Tensor[] roots, bool retain_graph=False) -> ()",
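         // Runs the distributed autograd backward pass for the given context
         // id, using `roots` as the starting tensors.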
         [](Stack& stack) {
           bool retain_graph = pop(stack).toBool();
           auto roots_list = pop(stack).toTensorList();
           int64_t context_id = pop(stack).toInt();
           torch::autograd::variable_list roots(
               roots_list.begin(), roots_list.end());
           dist_autograd::backward(context_id, roots, retain_graph);
         },
         aliasAnalysisConservative()),
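     // The rpc_* prim ops take a variable number of inputs (args, kwargs and
     // timeout are optional), so each Operation is built per-Node with that
     // node's input count captured in the closure.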
     Operator(
         prim::rpc_sync,
         [](const Node* node) -> Operation {
           int num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             prepare_and_call_rpc_op(stack, num_inputs, "rpc_sync");
           };
         },
         aliasAnalysisSpecialCase()),
     Operator(
         prim::rpc_remote,
         [](const Node* node) -> Operation {
           int num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             prepare_and_call_rpc_op(stack, num_inputs, "rpc_remote");
           };
         },
         aliasAnalysisSpecialCase()),
     Operator(
         prim::rpc_async,
         [](const Node* node) -> Operation {
           int num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             prepare_and_call_rpc_op(stack, num_inputs, "rpc_async");
           };
         },
         aliasAnalysisSpecialCase())});

// Implementations located in
// torch/csrc/jit/runtime/register_distributed_ops.cpp
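// The schema is declared elsewhere; this should surface to TorchScript as
// get_gradients(int context_id) -> Dict(Tensor, Tensor), returning the
// gradients accumulated in the given distributed autograd context.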
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
  m.impl("get_gradients", [](int64_t context_id) {
    const auto& autogradContext =
        dist_autograd::DistAutogradContainer::getInstance().retrieveContext(
            context_id);
    return autogradContext->getGradients();
  });
}

} // namespace
} // namespace torch::jit