#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_impl.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <torch/library.h>

#include <fmt/format.h>
#include <stdexcept>

namespace dist_autograd = torch::distributed::autograd;
namespace dist_rpc = torch::distributed::rpc;

namespace torch::jit {

namespace {
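// Ensures the dist_rpc.WorkerInfo custom class is registered with
// TorchScript before the RPC ops below are used (assumed from the
// RegisterWorkerInfoOnce helper's name and its use here).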
distributed::rpc::RegisterWorkerInfoOnce workerInfo{};

// prepare the rpc input arguments and call the C++ impls
void prepare_and_call_rpc_op(
    Stack& stack,
    int num_inputs,
    const std::string& rpc_op) {
  // Get inputs from the stack.
  auto stackIter = stack.end() - num_inputs;
  auto& dstWorkerIValue = *stackIter++;
  auto& qualifiedNameIValue = *stackIter++;
  IValue emptyTuple(c10::ivalue::Tuple::create({}));
  IValue emptyDict{c10::impl::GenericDict(AnyType::get(), AnyType::get())};
  // Equivalent to the Python statement
  // `args = args if args is not None else ()`.
  auto& argsTupleIValue = num_inputs >= 3 ? *stackIter++ : emptyTuple;
  // `kwargs = kwargs if kwargs is not None else {}`.
  auto& kwargsDictIValue = num_inputs >= 4 ? *stackIter++ : emptyDict;

  // IValue corresponding to the placeholder for the RPC timeout. Used if no
  // RPC timeout is specified by the user.
  IValue noTimeout(torch::distributed::rpc::kUnsetRpcTimeout);
  const auto rpcMaxInputs = 5;
  auto& timeoutIValue = num_inputs >= rpcMaxInputs ? *stackIter++ : noTimeout;
  TORCH_INTERNAL_ASSERT(
      dstWorkerIValue.isString() ||
      c10::getCustomClassType<c10::intrusive_ptr<dist_rpc::WorkerInfo>>() ==
          dstWorkerIValue.type());
  TORCH_INTERNAL_ASSERT(qualifiedNameIValue.isString());
  TORCH_INTERNAL_ASSERT(argsTupleIValue.isTuple());
  TORCH_INTERNAL_ASSERT(kwargsDictIValue.isGenericDict());
  TORCH_INTERNAL_ASSERT(timeoutIValue.isDouble());

  // Get the FunctionSchema for qualifiedName.
  auto qualifiedName = c10::QualifiedName(qualifiedNameIValue.toStringRef());
  std::shared_ptr<CompilationUnit> cuPtr;
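  // Looking up the Python compilation unit may touch Python state, so hold
  // the GIL while fetching it.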
  {
    py::gil_scoped_acquire acquire;
    cuPtr = get_python_cu();
  }
  auto& functionSchema = cuPtr->get_function(qualifiedName).getSchema();

  // Build the stack for the user callable.
  // This is similar to
  //   Stack createStackForSchema(FunctionSchema, py::args, py::kwargs),
  // except that the inputs here are
  //   Stack createStackForSchema(FunctionSchema, IValue<Tuple>, IValue<Dict>).
  Stack userCallableStack;
  userCallableStack.reserve(functionSchema.arguments().size());

  // Move args from the Tuple IValue to the stack.
  for (auto& elem : argsTupleIValue.toTupleRef().elements()) {
    push(userCallableStack, std::move(elem));
  }

  // Move kwargs from the Dict IValue to the stack, filling in declared
  // default values for any arguments that were not supplied.
  size_t consumed_kwargs = 0;
  auto kwargsDict = kwargsDictIValue.toGenericDict();
  for (size_t i = userCallableStack.size();
       i < functionSchema.arguments().size();
       ++i) {
    const auto& arg = functionSchema.arguments()[i];
    const auto& argName = arg.name();
    if (kwargsDict.contains(argName)) {
      push(userCallableStack, kwargsDict.at(argName));
      consumed_kwargs += 1;
    } else if (arg.default_value()) {
      push(userCallableStack, *arg.default_value());
    } else {
      throw std::runtime_error(c10::str(
          functionSchema.name(),
          "() is missing value for argument '",
          argName,
          "'. Declaration: ",
          functionSchema));
    }
  }
  // Raise an exception naming the unexpected kwargs.
  if (consumed_kwargs != kwargsDict.size()) {
    std::vector<std::string> names;
    for (const auto& entry : kwargsDict) {
      const IValue& keyIValue = entry.key();
      const std::string& keyStr = keyIValue.toStringRef();
      names.emplace_back(keyStr);
    }
    throw std::runtime_error(functionSchema.findErrorInKwargs(names));
  }

  // Get the destination worker name.
  std::string dstWorkerNameStr;
  if (dstWorkerIValue.isString()) {
    // ivalue::ConstantString::str_ is a const member, which can't be moved,
    // so copy it here.
    dstWorkerNameStr = dstWorkerIValue.toStringRef();
  } else {
    dstWorkerNameStr =
        dstWorkerIValue.toCustomClass<dist_rpc::WorkerInfo>()->name_;
  }
  // Get the RPC timeout, if specified by the user.
  const auto rpcTimeout = timeoutIValue.toDouble();

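  // Dispatch on the op name: rpc_async leaves a Future on the stack,
  // rpc_sync waits on the Future and leaves its value, and rpc_remote
  // leaves an RRef to the remote result.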
  if (rpc_op == "rpc_async") {
    // Send the RPC request.
    auto futureIValuePtr = dist_rpc::rpcTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    // Push the output onto the stack.
    drop(stack, num_inputs);
    stack.emplace_back(std::move(futureIValuePtr));
  } else if (rpc_op == "rpc_sync") {
    // Send the RPC request and block until the result (or an error) arrives.
    auto futureIValuePtr = dist_rpc::rpcTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    futureIValuePtr->wait();
    if (futureIValuePtr->hasError()) {
      // Rethrow the remote error locally.
      throw std::runtime_error(futureIValuePtr->tryRetrieveErrorMessage());
    } else {
      auto res = futureIValuePtr->value();
      // Push the output onto the stack.
      drop(stack, num_inputs);
      stack.emplace_back(std::move(res));
    }
  } else if (rpc_op == "rpc_remote") {
    auto rrefPtr = dist_rpc::remoteTorchscript(
        dstWorkerNameStr,
        qualifiedName,
        functionSchema,
        userCallableStack,
        rpcTimeout);
    // Push the resulting RRef onto the stack.
    drop(stack, num_inputs);
    stack.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rrefPtr));
  } else {
    throw std::runtime_error(
        c10::str(rpc_op, "() is not supported in TorchScript!"));
  }
}

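// At the TorchScript level, these operators back calls of the form below
// (illustrative sketch only; `dst`, `my_script_fn`, `x`, and `y` are
// placeholders, not names defined in this file):
//
//   fut = rpc.rpc_async(dst, my_script_fn, args=(x,), kwargs={"y": y})
//   ret = rpc.rpc_sync(dst, my_script_fn, args=(x,))
//   rref = rpc.remote(dst, my_script_fn, args=(x,))
//   val = rref.to_here()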
RegisterOperators reg_rpc_ops(
    {Operator(
         fmt::format(
             "aten::to_here(RRef(t) self, float timeout = {}) -> t(*)",
             torch::distributed::rpc::kDefaultRpcTimeoutSeconds),
         [](Stack& stack) {
           auto timeout = pop(stack).toDouble();
           auto rref = pop(stack).toRRef();
           IValue res;
           if (rref->isOwner()) {
             res =
                 c10::dynamic_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
                     ->getValue();
           } else {
             res = c10::dynamic_intrusive_pointer_cast<dist_rpc::UserRRef>(rref)
                       ->toHere(timeout);
           }
           push(stack, std::move(res));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::local_value(RRef(t) self) -> t(*)",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           TORCH_CHECK(
               rref->isOwner(),
               "Can't call RRef.local_value() on a non-owner RRef.");
           IValue res =
               c10::static_intrusive_pointer_cast<dist_rpc::OwnerRRef>(rref)
                   ->getValue();
           push(stack, std::move(res));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::is_owner(RRef(t) self) -> bool",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(stack, rref->isOwner());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::owner(RRef(t) self) -> __torch__.torch.classes.dist_rpc.WorkerInfo",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(
               stack,
               torch::make_custom_class<distributed::rpc::WorkerInfo>(
                   rref->ownerName(), rref->owner()));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::owner_name(RRef(t) self) -> str",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(stack, rref->ownerName());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::confirmed_by_owner(RRef(t) self) -> bool",
         [](Stack& stack) {
           auto rref = pop(stack).toRRef();
           push(stack, rref->confirmedByOwner());
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::dist_backward(int context_id, Tensor[] roots, bool retain_graph=False) -> ()",
         [](Stack& stack) {
           bool retain_graph = pop(stack).toBool();
           auto roots_list = pop(stack).toTensorList();
           int64_t context_id = pop(stack).toInt();
           torch::autograd::variable_list roots(
               roots_list.begin(), roots_list.end());
           dist_autograd::backward(context_id, roots, retain_graph);
         },
         aliasAnalysisConservative()),
     Operator(
         prim::rpc_sync,
         [](const Node* node) -> Operation {
           int num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             prepare_and_call_rpc_op(stack, num_inputs, "rpc_sync");
           };
         },
         aliasAnalysisSpecialCase()),
     Operator(
         prim::rpc_remote,
         [](const Node* node) -> Operation {
           int num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             prepare_and_call_rpc_op(stack, num_inputs, "rpc_remote");
           };
         },
         aliasAnalysisSpecialCase()),
     Operator(
         prim::rpc_async,
         [](const Node* node) -> Operation {
           int num_inputs = node->inputs().size();
           return [num_inputs](Stack& stack) {
             prepare_and_call_rpc_op(stack, num_inputs, "rpc_async");
           };
         },
         aliasAnalysisSpecialCase())});

// Implementations located in
// torch/csrc/jit/runtime/register_distributed_ops.cpp
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
  m.impl("get_gradients", [](int64_t context_id) {
    const auto& autogradContext =
        dist_autograd::DistAutogradContainer::getInstance().retrieveContext(
            context_id);
    return autogradContext->getGradients();
  });
}
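
// Note: this op returns the Tensor -> gradient map accumulated in the given
// distributed autograd context; it is assumed to back
// torch.distributed.autograd.get_gradients(context_id) when that call is
// made from TorchScript (the eager Python path may be bound separately).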

} // namespace
} // namespace torch::jit